xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/simd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import collections
5import contextlib
6import dataclasses
7import functools
8import itertools
9import logging
10import math
11import operator
12from typing import (
13    Any,
14    Callable,
15    Counter,
16    DefaultDict,
17    Dict,
18    Iterable,
19    List,
20    Optional,
21    Sequence,
22    Tuple,
23    Union,
24)
25
26import sympy
27
28import torch
29import torch._logging
30from torch.utils._ordered_set import OrderedSet
31from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing
32from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
33
34from ..._dynamo.utils import counters
35from .. import config, ir, scheduler
36from ..codecache import code_hash
37from ..dependencies import Dep, MemoryDep, StarDep, WeakDep
38from ..ir import IRNode, TritonTemplateBuffer
39from ..optimize_indexing import indexing_dtype_strength_reduction
40from ..runtime.hints import ReductionHint
41from ..runtime.runtime_utils import green_text, yellow_text
42from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse
43from ..utils import (
44    get_dtype_size,
45    IndentedBuffer,
46    Placeholder,
47    sympy_index_symbol,
48    sympy_product,
49    sympy_subs,
50    unique,
51)
52from ..virtualized import ops, OpsWrapper, V
53from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter
54from .multi_kernel import MultiKernel
55
56
57log = logging.getLogger(__name__)
58perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
59schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
60fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
61
62
63pexpr = PythonPrinter().doprint
64
65
66@dataclasses.dataclass
67class IterationRanges:
68    """
69    Each range tree represents multiple sets of iteration indexing
70    in a single tiled dimension in the output kernel.
71
72    If you have two loop ranges, one (4, 3, 2) and another (4, 6),
73    then the range tree will be:
74            4 (i0)
75        3 (i1)  6 (i3)
76        2 (i2)
77    Where i0 is shared between both loops, but they then split into
78    different indexing vars.  All loop ranges must iterate over
79    the same number of elements.
80    """
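    # Editor's note (an illustrative sketch, not part of the original docstring):
    # entries are created lazily via lookup(divisor, length) and keyed by their
    # index expression.  For the example above with numel 24, both loops end up
    # requesting the split (divisor=6, length=4), so they share that entry (i0),
    # while the inner (3, 2) splits and the (6) split get distinct entries.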
81
82    def __init__(
83        self,
84        name: str,
85        var_list: List[sympy.Symbol],
86        var_ranges: Dict[sympy.Symbol, sympy.Expr],
87        numel: sympy.Expr,
88        prefix: str,
89        *,
90        kernel: SIMDKernel,
91        divisor=sympy.Integer(1),
92        length=sympy.Integer(1),
93        root: IterationRangesRoot,
94    ) -> None:
95        super().__init__()
96        self.name = name
97        self.var_list = var_list
98        self.var_ranges = var_ranges
99        self.numel = numel
100        self.prefix = prefix
101        self.divisor = divisor
102        self.length = length
103        self.kernel = kernel
104        self.root = root
105
106    def symbol(self):
107        return sympy_index_symbol(self.name)
108
109
110class IterationRangesRoot(IterationRanges):
111    def __init__(
112        self,
113        name: str,
114        numel: sympy.Expr,
115        # TODO: this is probably SymT.INDEX and SymT.RINDEX
116        prefix: str,
117        index: int,
118        kernel: SIMDKernel,
119        pid_cache=None,
120        *,
121        is_loop: bool,
122        tensor_dim: Optional[int],
123        grid_dim: Optional[int],
124        has_zdim: bool,
125    ) -> None:
126        if pid_cache is None:
127            pid_cache = {}
128        super().__init__(
129            name=name,
130            var_list=[],
131            var_ranges={},
132            numel=numel,
133            prefix=prefix,
134            kernel=kernel,
135            root=self,
136        )
137        self.index = index
138        # Store all the nodes in one flat dict, keyed by index expression
139        self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {}
140        # This is for re-ordering program ID in triton mm template
141        # pid_cache["tl.program_id(0)"] = pid_m
142        self.pid_cache: Dict[str, str] = pid_cache
143
144        # True if the dimension is implemented as a single program looping over
145        # the full dimension (currently only used for non-persistent reduction)
146        assert not is_loop or (prefix == "r" and grid_dim is None)
147        self.is_loop = is_loop
148        # Index of corresponding dimension on triton tensors
149        self.tensor_dim = tensor_dim
150        # Index of corresponding dimension in the triton grid
151        self.grid_dim = grid_dim
152        self.has_zdim = has_zdim
153
154    def __repr__(self) -> str:
155        return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)"
156
157    def cache_clear(self):
158        for node in self.nodes.values():
159            node.cache_clear()
160
161    def index_sym(self):
162        return sympy_index_symbol(f"{self.prefix}index")
163
164    def lookup(self, divisor, length):
165        """
166        Lookup a given RangeTreeEntry, creating it if needed
167        """
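        # Editor's note (hedged): when divisor * length covers the full numel,
        # this is the outermost split and needs no wrap-around, so a plain
        # FloorDiv suffices; otherwise ModularIndexing(index, divisor, length),
        # i.e. (index // divisor) % length, is used.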
168        if V.graph.sizevars.statically_known_equals(divisor * length, self.numel):
169            expr = FloorDiv(self.index_sym(), divisor)
170        else:
171            expr = ModularIndexing(self.index_sym(), divisor, length)
172
173        if expr not in self.nodes:
174            node = IterationRangesEntry(
175                f"{self.prefix}{next(V.kernel.iter_vars_count)}",
176                divisor,
177                length,
178                expr,
179                self,
180            )
181            V.kernel.range_tree_nodes[node.symbol()] = node
182            self.var_list.append(node.symbol())
183            self.var_ranges[node.symbol()] = length
184            self.nodes[expr] = node
185        return self.nodes[expr]
186
187    def construct_entries(self, lengths: List[sympy.Expr]):
188        divisor = sympy.Integer(1)
189        itervars = []
190        for length in reversed(lengths):
191            itervars.append(self.lookup(divisor, length))
192            divisor = divisor * length
193        return list(reversed(itervars))
194
195    def construct(self, lengths: List[sympy.Expr]):
196        return [e.symbol() for e in self.construct_entries(lengths)]
197
198    def vars_and_sizes(self, index: sympy.Expr):
199        """Figure out vars from this tree used in index"""
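        # Editor's sketch (assumption, not from the original source): nodes are
        # visited in order of increasing divisor, and any gap between a node's
        # divisor and the running product of lengths is filled with a placeholder
        # var, so the returned sizes always multiply out to this tree's numel.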
200        nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
201        nodes = [n for n in nodes if n and n.prefix == self.prefix]
202        nodes.sort(
203            key=lambda x: V.graph.sizevars.size_hint(
204                x.divisor, fallback=config.unbacked_symint_fallback
205            )
206        )
207        divisor = sympy.Integer(1)
208        index_vars = []
209        sizes = []
210
211        def add(node):
212            nonlocal divisor
213            index_vars.append(node.symbol())
214            sizes.append(node.length)
215            divisor = divisor * node.length
216
217        for node in nodes:
218            if not V.graph.sizevars.statically_known_equals(node.divisor, divisor):
219                # fill in unused index var
220                add(self.lookup(divisor, FloorDiv(node.divisor, divisor)))
221                divisor = node.divisor
222            add(node)
223        if not V.graph.sizevars.statically_known_equals(self.numel, divisor):
224            # fill in unused index var
225            add(self.lookup(divisor, FloorDiv(self.numel, divisor)))
226
227        return list(reversed(index_vars)), list(reversed(sizes))
228
229
230class IterationRangesEntry(IterationRanges):
231    def __init__(
232        self,
233        name: str,
234        divisor: sympy.Expr,
235        length: sympy.Expr,
236        expr: sympy.Expr,
237        parent: IterationRanges,
238    ) -> None:
239        super().__init__(
240            name=name,
241            numel=parent.numel / length,
242            var_list=parent.var_list,
243            var_ranges=parent.var_ranges,
244            prefix=parent.prefix,
245            divisor=divisor,
246            length=length,
247            kernel=parent.kernel,
248            root=parent.root,
249        )
250        self.parent = parent
251        self.codegen = functools.lru_cache(None)(self._codegen)
252        self.expr = expr
253
254    def __repr__(self) -> str:
255        return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})"
256
257    def set_name(self, name):
258        self.codegen = lambda: name  # type: ignore[assignment]
259        self.codegen.cache_clear = lambda: None  # type: ignore[method-assign]
260        self.name = name
261
262    def cache_clear(self):
263        self.codegen.cache_clear()
264
265    def _codegen(self):
266        V.kernel.codegen_iteration_ranges_entry(self)
267        return self.name
268
269    def precomputed_args(self):
270        # for dynamic shapes, find parts of indexing expressions that have to be precomputed
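        # Hypothetical example (editor's illustration): for an entry whose expr is
        # ModularIndexing(xindex, ceiling(s0/2), s1), the non-constant divisor
        # ceiling(s0/2) involves only SymT.SIZE symbols, so it is returned here and
        # later replaced by a precomputed kernel argument.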
271        precomputed_args: List[sympy.Expr] = []
272        if isinstance(self.expr, sympy.Symbol):
273            return precomputed_args
274        assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr)
275        for arg in self.expr.args[1:]:
276            if not isinstance(arg, (sympy.Integer, sympy.Symbol)):
277                symbols = arg.free_symbols
278                if len(symbols) > 0 and all(
279                    symbol_is_type(s, SymT.SIZE) for s in symbols
280                ):
281                    precomputed_args.append(arg)
282        return precomputed_args
283
284    def __hash__(self):
285        return hash(self.name)
286
287    def __eq__(self, other):
288        return self.name == other.name
289
290
291def constant_repr(value):
292    if value == float("inf"):
293        return 'float("inf")'
294    elif value == float("-inf"):
295        return 'float("-inf")'
296    elif math.isnan(value):
297        return 'float("nan")'
298    return repr(value)
299
300
301class SIMDKernel(Kernel):
302    """
303    Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests.
304    """
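    # Editor's illustration (assumption): e.g. a pointwise op over a (4, 8) tensor
    # is emitted as a single flattened range with xnumel = 32 indexed by xindex,
    # rather than as a nested 4 x 8 loop; the range trees recover per-dimension
    # variables from that flat index when an access pattern needs them.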
305
306    sexpr = pexpr
307    kexpr: Callable[[sympy.Expr], str]
308    allow_block_ptr = False
309
310    def __init__(
311        self,
312        *groups,
313        index_dtype: str,
314        mutations: Optional[OrderedSet[str]] = None,
315        pid_cache=None,
316        reduction_hint=ReductionHint.DEFAULT,
317        override_persistent_reduction=None,
318    ) -> None:
319        if pid_cache is None:
320            pid_cache = {}
321        super().__init__()
322        self.body = IndentedBuffer()
323        self.indexing_code = IndentedBuffer()
324        self.numels = [V.graph.sizevars.simplify(s) for s in groups]
325        self.mutations: OrderedSet[str] = (
326            mutations if mutations is not None else OrderedSet()
327        )
328        self.range_trees: List[IterationRangesRoot] = []
329        self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
330        self.iter_vars_count = itertools.count()
331        self.inside_reduction = self.numels[-1] != 1
332        self.reduction_hint = reduction_hint
333        self.index_dtype: str = index_dtype
334        self.last_usage: OrderedSet[str] = OrderedSet()
335        self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
336        self.persistent_reduction: bool = (
337            override_persistent_reduction
338            if override_persistent_reduction is not None
339            else self.should_use_persistent_reduction()
340        )
341        self.no_x_dim = self.want_no_x_dim()
342        self.code_hash: Union[str, None] = None
343
344        # define this in a closure to make cache local to object
345        @functools.lru_cache(None)
346        def simplify_indexing(index: sympy.Expr):
347            index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges())
348            for tree in self.range_trees:
349                index = self.combine_contiguous_dims(index, tree)
350
351            return self.combine_modular_indexing_pairs(index)
352
353        self.simplify_indexing = simplify_indexing
354        self.initialize_range_tree(pid_cache)
355
356    def want_no_x_dim(self):
357        return False
358
359    def initialize_range_tree(self, pid_cache):
360        no_r_dim = not self.inside_reduction or self.numels[-1] == 1
361
362        prefixes = "zyxr"
363        active_prefixes = prefixes[-len(self.numels) :]
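        # Editor's note (hedged): prefixes come from the tail of "zyxr", so two
        # numels (pointwise numel, reduction numel) yield "xr", three yield "yxr",
        # and four yield "zyxr".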
364
365        grid_dims = "xyz"
366        if self.no_x_dim:
367            tensor_dims = "r"
368        elif no_r_dim:
369            tensor_dims = "xyz"
370        else:
371            tensor_dims = "xyzr"
372
373        tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes)
374
375        for i, prefix in enumerate(active_prefixes):
376            is_reduction = prefix == "r"
377            tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None
378            grid_dim = None if is_reduction else grid_dims.find(prefix)
379            index = i if grid_dim is None else grid_dim
380            self.range_trees.append(
381                IterationRangesRoot(
382                    f"{prefix}index",
383                    self.numels[i],
384                    prefix,
385                    index,
386                    self,
387                    pid_cache=pid_cache,
388                    is_loop=is_reduction and not self.persistent_reduction,
389                    tensor_dim=tensor_dim,
390                    grid_dim=grid_dim,
391                    has_zdim="z" in active_prefixes,
392                )
393            )
394
395    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
396        """
397        Hook called right before codegen with every index that will be
398        used in the fused kernel.
399        """
400
401    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
402        prior = self.inside_reduction
403        self.inside_reduction = False
404        try:
405            return self.store(name, index, value)
406        finally:
407            self.inside_reduction = prior
408
409    def should_use_persistent_reduction(self) -> bool:
410        return False  # defined in subclass
411
412    def var_ranges(self):
413        return dict(
414            itertools.chain.from_iterable(
415                tree.var_ranges.items() for tree in self.range_trees
416            )
417        )
418
419    def triton_tensor_ndim(self):
420        return sum(int(tree.tensor_dim is not None) for tree in self.range_trees)
421
422    def indexing_size_str(self, i):
423        sizes = ["None"] * self.triton_tensor_ndim()
424        sizes[i] = ":"
425        return f"[{', '.join(sizes)}]"
426
427    def dense_size_list(self) -> List[str]:
428        sizes = ["1"] * self.triton_tensor_ndim()
429        for tree in self.range_trees:
430            if tree.tensor_dim is None:
431                continue
432
433            if tree.prefix != "r" or self.inside_reduction:
434                sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
435        return sizes
436
437    def dense_size_str(self):
438        sizes = self.dense_size_list()
439        return f"[{', '.join(sizes)}]"
440
441    def combine_modular_indexing_pairs(self, index):
442        if not isinstance(index, ModularIndexing):
443            return index
444        x = index.args[0]
445        if (tree_node := self.range_tree_nodes.get(x)) is None:
446            return index
447        new_index = sympy_subs(index, {x: tree_node.expr})
448        new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index)
449        # the index now contains xindex/etc, which is nonstandard, fix it up
450        return sympy_subs(
451            new_index,
452            {
453                tree_node.root.index_sym(): tree_node.root.lookup(
454                    sympy.Integer(1), tree_node.root.numel
455                ).symbol()
456            },
457        )
458
459    def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
460        if expand_res := V.graph.sizevars.expand_floor_div(index):
461            new_index, denominator = expand_res  # type: ignore[misc]
462            return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator)
463        else:
464            return self._combine_contiguous_dims(index, tree)
465
466    def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
467        """
468        More aggressive simplification to merge contiguous dims
469        """
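        # Editor's sketch (illustrative, not from the original source): an index
        # such as x0 + 4*x1, where x0 has length 4, x1 has length 3, and the two
        # are iterated contiguously, can be rewritten in terms of a single
        # variable of length 12, leaving fewer iteration vars in the expression.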
470        if isinstance(index, (sympy.Integer, sympy.Symbol)):
471            return index
472        index_vars, sizes = tree.vars_and_sizes(index)
473        if len(sizes) <= 1:
474            return index
475        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
476            index_vars, sizes, index_prevent_reordering([index], index_vars, sizes)
477        )
478        if new_sizes == sizes:
479            return index
480        new_index_vars = tree.construct(new_sizes)
481        new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
482        return new_index
483
484    def set_last_usage(self, nodes):
485        if not self.inside_reduction or self.persistent_reduction:
486            return
487        self.last_usage = OrderedSet(
488            itertools.chain.from_iterable(
489                n.last_usage for n in nodes if n is not EnableReduction
490            )
491        )
492
493    def disable_reduction(self):
494        should_flush = self.range_trees[-1].is_loop
495
496        @contextlib.contextmanager
497        def ctx():
498            if self.numels[-1] == 1:
499                assert not self.inside_reduction
500                yield
501                return
502            if should_flush:
503                # calling codegen_body() will flush all the pending buffers
504                # and write out a reduction loop
505                self.codegen_body()
506            self.inside_reduction = False
507            try:
508                yield
509                if should_flush:
510                    # flush out any code before opening the next loop
511                    self.codegen_body()
512            finally:
513                self.inside_reduction = True
514
515        return ctx()
516
517    def set_ranges(self, *lengths):
518        assert len(lengths) == len(self.range_trees)
519        return [
520            ranges.construct(length)
521            for length, ranges in zip(lengths, self.range_trees)
522        ]
523
524    @staticmethod
525    def _split_iteration_ranges(
526        groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
527    ):
528        sv = V.graph.sizevars
529        new_ranges: List[List[sympy.Expr]] = [[] for _ in groups]
530        remaining = [sv.simplify(g) for g in groups]
531        var_count = itertools.count()
532
533        def add_range(i, expr):
534            expr = sv.simplify(expr)
535            if not sv.statically_known_multiple_of(remaining[i], expr):
536                raise CantSplit
537            # guard on the last item out
538            remaining[i] = FloorDiv(remaining[i], expr)
539            new_ranges[i].append(expr)
540            return next(var_count)
541
542        def make_combined(size, idx1, idx2):
543            def getter(flat_vars):
544                return size * flat_vars[idx1] + flat_vars[idx2]
545
546            return getter
547
548        return_getters_groups = []
549        current_group = 0
550        for length_group in lengths:
551            return_getters = []
552            for size in length_group:
553                if sv.statically_known_equals(size, 1):  # type: ignore[arg-type]
554                    return_getters.append(lambda _: sympy.Integer(0))
555                    continue
556
557                while current_group < len(remaining) and sv.statically_known_equals(
558                    remaining[current_group], 1  # type: ignore[arg-type]
559                ):
560                    # scroll to next group with remaining elements
561                    current_group += 1
562
563                if current_group + 1 < len(remaining) and sv.statically_known_gt(
564                    size, remaining[current_group]
565                ):
566                    # need to break size in two
567                    if not sv.statically_known_multiple_of(
568                        size, remaining[current_group]
569                    ):
570                        raise CantSplit
571                    size1 = remaining[current_group]
572                    size2 = FloorDiv(size, remaining[current_group])
573                    return_getters.append(
574                        make_combined(
575                            size2,
576                            add_range(current_group, size1),
577                            add_range(current_group + 1, size2),
578                        )
579                    )
580                else:
581                    return_getters.append(
582                        operator.itemgetter(add_range(current_group, size))
583                    )
584            return_getters_groups.append(return_getters)
585
586        assert all(
587            V.graph.sizevars.size_hint(s) == 1 for s in remaining
588        ), f"failed to set ranges {remaining} {lengths}"
589
590        return new_ranges, return_getters_groups
591
592    @classmethod
593    def is_compatible(
594        cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
595    ):
596        try:
597            cls._split_iteration_ranges(groups, lengths)
598            return True
599        except CantSplit:
600            return False
601
602    def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
603        """
604        We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1).
605
606        To do this we need to split up the iteration space of i0 into something like:
607            for i1 in s0:
608              for i2 in s1:
609                i0 = i1*s1 + i2
610                ....
611
612        This function matches and resplits lengths to the groups of
613        this kernel to enable tiled + non-tiled fusions.
614        """
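        # Editor's example (hedged sketch): with a kernel tiled into groups
        # (s0, s1) and an incoming node whose flat iteration range is s0*s1, the
        # single range is split across the two groups and the node's flat index
        # is rebuilt as i1*s1 + i2, matching the expansion in the docstring above.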
615        groups = [rt.numel for rt in self.range_trees]
616        if not self.inside_reduction:
617            groups[-1] = sympy.Integer(1)
618
619        if len(lengths) == len(self.range_trees) and all(
620            V.graph.sizevars.simplify(sympy_product(x) - g) == 0
621            for x, g in zip(lengths, groups)
622        ):
623            return self.set_ranges(*lengths)
624
625        new_ranges, return_getters_groups = self._split_iteration_ranges(
626            groups, lengths
627        )
628        itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges)))
629        return [[fn(itervars) for fn in fns] for fns in return_getters_groups]
630
631    def is_indirect_indexing(self, index: sympy.Expr):
632        # tmpX means indirect indexing
633        return free_symbol_is_type(index, SymT.TMP)
634
635    def is_broadcasted(self, index: sympy.Expr):
636        # Note. This may not be correct when there is indirect indexing
637        if self.is_indirect_indexing(index):
638            return False
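        # Editor's example (assumption): in a kernel with numels (64, 32), an
        # index built only from x-dim vars covering all 64 elements but no r-dim
        # vars gives index_numels == [64, 1] != [64, 32], so it is broadcasted
        # along the reduction dimension.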
639
640        index_numels = [1] * len(self.numels)
641        for symbol in index.free_symbols:
642            if symbol not in self.range_tree_nodes:
643                # Non-iterated variables, e.g. strides
644                continue
645            entry = self.range_tree_nodes[symbol]  # type: ignore[index]
646            assert isinstance(entry.parent, IterationRangesRoot)
647            index_numels[entry.parent.index] *= entry.length
648
649        # If the index variables only iterate over a subset of the kernel
650        # numels, then it must be broadcasted.
651        simplify = V.graph.sizevars.simplify
652        return any(
653            simplify(idx_range) != simplify(iter_range)  # type: ignore[arg-type]
654            for idx_range, iter_range in zip(index_numels, self.numels)
655        )
656
657    def index_to_str(self, index: sympy.Expr) -> str:
658        """
659        Convert an index expr to a string that can be used in output code.
660        e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel.
661
662        Index expressions often need to be passed in as arguments to the triton kernel.
663        Rename_indexing and codegen_indexing keep track of the needed indices and add
664        new parameters to the function signature.
665        """
666        if isinstance(index, list):
667            return f"[{', '.join(map(self.index_to_str, index))}]"
668        return self.kexpr(self.rename_indexing(index))  # type: ignore[call-arg]
669
670    def prepare_indexing(
671        self,
672        index: sympy.Expr,
673    ):
674        index = self.simplify_indexing(index)
675        index = sympy_subs(index, V.graph.sizevars.precomputed_replacements)
676        # if simple replacements didn't get rid of floor/ceil, try full subs
677        if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)):
678            index = index.subs(V.graph.sizevars.precomputed_replacements)
679        # last resort, if no range vars are in the expr, hoist it
680        # TODO instead of trying to blindly find complicated exprs, we should hoist the
681        # inputs/outputs sizes and strides, but at the time indexing is generated,
682        # kernel inputs and outputs are not set yet; we'd need a deeper refactor
683        # to do it this way.
684
685        if len(index.atoms(sympy.ceiling)):
686            for a in index.atoms(sympy.ceiling):
687                # for nested exprs, atoms yields top level first (?)
688                # so if everything goes fine, lower level replacements will come up empty
689                symbols = a.free_symbols
690                if len(symbols) > 0 and all(
691                    symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE))
692                    for s in symbols
693                ):
694                    replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)}
695                    index = sympy_subs(index, replacements)
696
697        simp_index = self.simplify_indexing(index)
698
699        # Now that we are done simplifying, we can unwrap Identity so that downstream
700        # handling of its contained expression will work. Previously, tl.full wrapping
701        # of sympy.Integer would not occur.
702        simp_index = (
703            simp_index if not isinstance(simp_index, Identity) else simp_index.args[0]
704        )
705
706        return self.codegen_indexing(simp_index)
707
708    def active_range_trees(self, reorder=False):
709        trees = [
710            t for t in self.range_trees if t.prefix != "r" or self.inside_reduction
711        ]
712        if reorder and len(trees) > 1:
713            count = sum(t.prefix in "xyz" for t in trees)
714            assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [
715                t.prefix for t in trees[:count]
716            ]
717            trees[:count] = reversed(trees[:count])
718        return trees
719
720    def codegen_indexing(self, expr: sympy.Expr):
721        expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
722        for sym in sorted(expr.free_symbols, key=str):
723            if sym in self.range_tree_nodes:
724                # if the indexing expression is complicated, we precompute it on the host side
725                # and send the result as a kernel argument
726                replacements = {}
727                for ps in self.range_tree_nodes[sym].precomputed_args():  # type: ignore[index]
728                    replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
729                if len(replacements) > 0:
730                    self.range_tree_nodes[sym].expr = sympy_subs(  # type: ignore[index]
731                        self.range_tree_nodes[sym].expr, replacements  # type: ignore[index]
732                    )
733                self.range_tree_nodes[sym].codegen()  # type: ignore[index]
734        return expr
735
736    def codegen_nan_check(self) -> None:
737        raise NotImplementedError("NYI: codegen_nan_check")
738
739    def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None:
740        raise NotImplementedError("NYI: call_kernel")
741
742    @contextlib.contextmanager
743    def mask_loads(self, mask, value):
744        """Context manager to add an additional mask to tl.load/store"""
745        prior = self._load_mask
746        prior_val = self._load_other
747        if prior:
748            mask = ops.logical_and(mask, prior)
749
750        mask = OpsWrapper._unwrap(mask)
751        self._load_mask = mask
752        self._load_other = value
753        try:
754            # TODO(jansel): do we need a reshape here?
755            yield mask
756        finally:
757            self._load_mask = prior
758            self._load_other = prior_val
759
760    def get_strides_of_load(self, index: sympy.Expr):
761        """
762        This gets the stride of the index for each of the tiling variables
763        (technically, it does it at index 0)
764
765        For example, if
766        index = x0 + 512*x1 + 1024*r0
767        x0 = (xindex // 512)
768        x1 = (xindex % 512)
769        r0 = rindex
770
771        this function would return
772        {xindex: 512, rindex: 1024}
773        """
774        index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()}
775        index_in_tile_vars = sympy_subs(index, index_to_tile_indexes)  # type: ignore[arg-type]
776        strides = {}
777        for range_tree in self.range_trees:
778            s = sympy_index_symbol(range_tree.name)
779            strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs(
780                index_in_tile_vars, {s: 0}
781            )
782        return strides
783
784    @staticmethod
785    def _map_tuple_or_scalar(fn, value):
786        if isinstance(value, tuple):
787            return tuple(map(fn, value))
788        return fn(value)
789
790    def estimate_kernel_num_bytes(self):
791        """
792        Make a best-effort estimate of the total size (in bytes) of the
793        kernel's inputs and outputs, which is used for estimating the memory
794        throughput of this kernel. This information is used to check how
795        far we are from the peak memory bandwidth. It is important to
796        avoid overestimating the sizes of the inputs and outputs, because
797        an inflated memory traffic value may even exceed the theoretical
798        bandwidth and thus become very misleading. This is particularly
799        problematic for cases where we slice some inputs. In those cases,
800        we should only count the size of the "slices" instead of the
801        original inputs, because only the slices contribute to the real
802        memory traffic.
803        """
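        # Editor's illustration (hedged): if a kernel reads a 1M-element buffer
        # only through a 4K-element slice, counting the full 1M elements would
        # inflate the estimated traffic; the accounting below instead scales the
        # number of distinct index expressions by the output numel.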
804        nbytes = []
805        ninplace_args = len(unique(self.args.inplace_buffers.values()))
806        _, call_args, _, _ = self.args.python_argdefs()
807
808        # For pointwise and reduction kernels, this is the upper-bound numels
809        # for the output buffer.
810        # FIXME: This is not exactly right for cases like below:
811        #    def foo(tensor0, tensor1):
812        #        x0 = narrow(tensor0)
813        #        return cat(x0, tensor1)
814        # For this example, we will end up overestimating the size for the
815        # slice x0. Potentially, we could have precise input information
816        # if we maintained the original inputs of the Pointwise kernel created
817        # for the "cat". However, adding such complexity only to handle some
818        # particular benchmarking cases seems like more trouble than it is
819        # worth.
820        out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels))
821        for i, arg in enumerate(call_args):
822            # "buf" may be narrowed. In this case, the number of memory accesses
823            # should be estimated based on the reinterpreted layout.
824            # On the other hand, buf may be broadcasted. In this case,
825            # counting the size of the underlying storage would give us
826            # a better estimation in terms of memory accesses.
827            if arg not in self.buf_accesses:
828                nbytes.append(0)
829                continue
830            arg_numel = V.graph.get_numel(arg)
831            buf_size = V.graph.sizevars.size_hint(arg_numel)
832            if buf_size > out_numel:
833                # This arg points to a buf that has been sliced.
834                # We need to count each individual slice to have
835                # a better estimation.
836                indices: OrderedSet[Any] = OrderedSet()
837                no_index_dep_count = 0
838                for dep in self.buf_accesses[arg]:
839                    if isinstance(dep, (StarDep, WeakDep)):
840                        indices.add(f"no_index_dep_{no_index_dep_count}")
841                        no_index_dep_count += 1
842                    else:
843                        indices.add(dep.index)
844                numel = len(indices) * out_numel
845            else:
846                numel = buf_size
847            dtype = V.graph.get_dtype(arg)
848            dtype_size = get_dtype_size(dtype)
849            nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args)))
850        return sum(nbytes)
851
852    def warn_mix_layout(self, kernel_name):
853        """
854        Print a message if the kernel has mixed-layout inputs.
855        Only 4D tensors are considered for now.
856        """
857        if (
858            len(self.args.input_buffers) == 1
859            and len(self.args.output_buffers) == 1
860            and len(self.args.inplace_buffers) == 0
861        ):
862            # even if the input buffer and output buffer have different layouts,
863            # this can be a layout-conversion kernel. No need to warn about
864            # mixed layouts.
865            return
866
867        argdefs, call_args, signature, _ = self.args.python_argdefs()
868        uniform_stride_order = None
869        for arg_name in call_args:
870            buf = V.graph.try_get_buffer(arg_name)
871            if buf and len(buf.layout.size) == 4:
872                # ignore the tensor if only 1 dimension has a size greater than 1
873                if len([x for x in buf.layout.size if x == 1]) == 3:
874                    continue
875                stride_order = ir.get_stride_order(buf.layout.stride)
876                if uniform_stride_order is None:
877                    uniform_stride_order = stride_order
878                elif uniform_stride_order != stride_order:
879                    msg = yellow_text(
880                        f"Expected stride order {uniform_stride_order}, but found stride order"
881                        + f" {stride_order} for kernel {kernel_name}"
882                    )
883                    log.warning(msg)
884
885                    stride_order_list = [
886                        ir.get_stride_order(V.graph.get_buffer(name).layout.stride)
887                        if V.graph.try_get_buffer(name)
888                        else None
889                        for name in call_args
890                    ]
891                    size_list = [
892                        V.graph.get_buffer(name).layout.size
893                        if V.graph.try_get_buffer(name)
894                        else None
895                        for name in call_args
896                    ]
897                    source_list = [
898                        "GraphInput"
899                        if name in V.graph.graph_inputs
900                        else "IntermediateBuffer"
901                        if name in V.graph.name_to_buffer
902                        else None
903                        for name in call_args
904                    ]
905
906                    msg = yellow_text(
907                        f"  param names {argdefs}\n  buf names {call_args}\n  strides {stride_order_list}"
908                        + f"\n  sizes {size_list}\n  sources {source_list}\n"
909                    )
910                    log.warning(msg)
911                    return
912        msg = green_text(
913            f"All the inputs for the triton kernel {kernel_name} have uniform layout"
914        )
915        log.warning(msg)
916
917    def welford_reduce_fallback(self, dtype, value):
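        # Editor's note (hedged): a two-pass fallback for Welford reductions:
        # mean = sum(x) / rnumel, then m2 = sum((x - mean) ** 2), returned
        # together with rnumel as the weight of the (mean, m2, weight) triple.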
918        sum_ = ops.reduction(dtype, dtype, "sum", value)
919        self.inside_reduction = False
920        rnumel = ops.index_expr(self.numels[-1], dtype)
921        mean = ops.truediv(sum_, rnumel)
922
923        self.inside_reduction = True
924        dx = ops.sub(value, mean)
925        dx2 = ops.mul(dx, dx)
926        m2 = ops.reduction(dtype, dtype, "sum", dx2)
927        return OpsWrapper._unwrap((mean, m2, rnumel))
928
929    def codegen_kernel(self):
930        raise NotImplementedError
931
932    def codegen_body(self):
933        pass
934
935    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
936        pass
937
938
939class SIMDScheduling(BaseScheduling):
940    kernel_type = SIMDKernel  # override in subclass
941    int32_type = "torch.int32"
942    int64_type = "torch.int64"
943
944    def __init__(self, scheduler) -> None:
945        super().__init__()
946        self.scheduler = scheduler
947
948    def group_fn(self, sizes):
949        return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes)
950
951    def can_fuse(self, node1, node2):
952        """
953        Hook called by Scheduler to determine if the Triton backend
954        can fuse node1 and node2.  These nodes might already be
955        FusedSchedulerNodes.
956        """
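        # Editor's summary (assumption, not from the original source): roughly,
        # two reductions fuse only if both (numel, rnumel) match; two pointwise
        # nodes additionally need compatible tilings; and a pointwise node can
        # fuse into a reduction when its numel equals numel * rnumel of the
        # reduction and its ranges can be re-split to match.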
957        if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance(
958            node2, scheduler.ForeachKernelSchedulerNode
959        ):
960            return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2)
961
962        _, (numel1, rnumel1) = node1.group
963        _, (numel2, rnumel2) = node2.group
964        why = WhyNoFuse(node1, node2)
965
966        if node1.is_split_scan() and not node2.is_split_scan():
967            if node2.is_reduction():
968                why("Split scan cannot fuse with reductions")
969        elif node2.is_split_scan() and not node1.is_split_scan():
970            if node1.is_reduction():
971                why("Split scan cannot fuse with reductions")
972
973        if node1.is_reduction() and node2.is_reduction():
974            reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2
975            if not reduction_can_fuse:
976                why(
977                    "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)",
978                    numel1,
979                    numel2,
980                    rnumel1,
981                    rnumel2,
982                )
983            return reduction_can_fuse
984
985        if not node1.is_reduction() and not node2.is_reduction():
986            if not (numel1 == numel2 and rnumel1 == rnumel2):
987                why(
988                    "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)",
989                    numel1,
990                    numel2,
991                    rnumel1,
992                    rnumel2,
993                )
994                return False
995
996            if node1.is_template():
997                # Only allow fusion for TritonTemplates for now.
998                # Fusion for CUDATemplates is not supported.
999                is_triton_template = isinstance(node1.node, TritonTemplateBuffer)
1000                if not is_triton_template:
1001                    why("node1 is not TritonTemplateBuffer")
1002                return is_triton_template
1003
1004            # check for a bad combined tiling
1005            tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
1006            tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1)
1007            tiling3 = self.select_tiling(
1008                node1.get_nodes() + node2.get_nodes(), numel1, rnumel1
1009            )
1010            if config.triton.tiling_prevents_pointwise_fusion:
1011                cond = True
1012                if len(tiling1) > 2:
1013                    if len(tiling2) > 2:
1014                        cond = tiling1 == tiling2 == tiling3
1015                    else:
1016                        cond = tiling1 == tiling3
1017                elif len(tiling2) > 2:
1018                    cond = tiling2 == tiling3
1019                if not cond:
1020                    why(
1021                        "tiling mismatch (%s, %s, %s)",
1022                        tiling1,
1023                        tiling2,
1024                        tiling3,
1025                    )
1026                    return False
1027
1028            return True
1029
1030        if not node1.is_reduction() and node2.is_reduction():
1031            assert rnumel1 == 1 and rnumel2 != 1
1032            if numel1 == numel2 * rnumel2:
1033                if not all(
1034                    SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges())
1035                    for n in node1.get_nodes()
1036                ):
1037                    why("nodes numel/rnumel incompatibility")
1038                    return False
1039                if (
1040                    config.triton.tiling_prevents_reduction_fusion
1041                    and not node1.is_template()
1042                ):
1043                    is_reduction_tiling_valid = self.select_tiling(
1044                        node1.get_nodes(), numel1
1045                    ) in (
1046                        (numel1, 1),
1047                        (numel2, rnumel2, 1),
1048                    )
1049                    if not is_reduction_tiling_valid:
1050                        why("invalid tiling for reduction")
1051                    return is_reduction_tiling_valid
1052                return True
1053
1054            if numel1 != numel2:
1055                why("nodes numel incompatibility")
1056            return numel1 == numel2
1057
1058        assert node1.is_reduction() and not node2.is_reduction()
1059        # swap args to hit the case above
1060        return self.can_fuse_horizontal(node2, node1)
1061
1062    can_fuse_vertical = can_fuse
1063    can_fuse_horizontal = can_fuse
1064
1065    def generate_node_schedule(self, nodes, numel, rnumel):
1066        node_schedule: List[Any] = []
1067        done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
1068        # Writes with a reduced shape, meaning they are only present once the
1069        # reduction loop has ended
1070        not_ready_yet_nodes: OrderedSet[str] = OrderedSet()
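        # Editor's example (hedged): for a kernel with (numel, rnumel) = (64, 32),
        # a node grouped as (64, 32) or (2048, 1) fits in the main body, while a
        # (64, 1) node only fits outside the reduction loop; see the helpers below.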
1071
1072        def fits_in_main_body(n):
1073            _, (node_numel, node_rnumel) = n.group
1074            return (node_numel == numel and node_rnumel == rnumel) or (
1075                node_numel == numel * rnumel and node_rnumel == 1
1076            )
1077
1078        def fits_outside_reduction(n):
1079            _, (node_numel, node_rnumel) = n.group
1080            return node_numel == numel and node_rnumel == 1 and rnumel != 1
1081
1082        def schedule_node_in_loop(n):
1083            done.add(n)
1084            node_schedule.append(n)
1085            # A scan is modelled as a reduction in the scheduler but has a
1086            # full sized output that can be used inside the loop body
1087            if (
1088                n.is_reduction()
1089                and isinstance(n, scheduler.SchedulerNode)
1090                and isinstance(n.node, ir.ComputedBuffer)
1091                and not isinstance(n.node.data, ir.Scan)
1092            ):
1093                not_ready_yet_nodes.add(n.get_name())
1094
1095        @contextlib.contextmanager
1096        def end_current_reduction_loop():
1097            if node_schedule and node_schedule[-1] is EnableReduction:
1098                node_schedule.pop()
1099            else:
1100                node_schedule.append(DisableReduction)
1101            yield
1102            node_schedule.append(EnableReduction)
1103            not_ready_yet_nodes.clear()
1104
1105        def requires_closing_previous_reduction(node, node_schedule):
1106            if rnumel == 1:
1107                return False
1108            if not not_ready_yet_nodes & node.ancestors:
1109                return False
1110            assert node_schedule and not isinstance(
1111                node_schedule[-1], (EnableReduction, DisableReduction)
1112            )
1113            return bool(not_ready_yet_nodes)
1114
1115        for index, node in enumerate(nodes):
1116            if node in done:
1117                continue
1118            done.add(node)
1119
1120            if fits_in_main_body(node):
1121                if requires_closing_previous_reduction(node, node_schedule):
1122                    with end_current_reduction_loop():
1123                        pass  # need to start a new reduction loop
1124
1125                schedule_node_in_loop(node)
1126            elif fits_outside_reduction(node):
1127                with end_current_reduction_loop():
1128                    node_schedule.append(node)
1129            else:
1130                raise NotImplementedError(
1131                    f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}"
1132                )
1133
1134        return node_schedule
1135
1136    def codegen_node(
1137        self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode]
1138    ):
1139        """
1140        Given a set of pre-fused nodes, generate a Triton kernel.
1141        """
1142
1143        nodes: List[scheduler.SchedulerNode] = node.get_nodes()  # type: ignore[assignment]
1144
1145        _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
1146
1147        node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
1148        buf_accesses = collections.defaultdict(list)
1149        for node in nodes:
1150            for access in node.read_writes.reads | node.read_writes.writes:
1151                buf_accesses[access.name].append(access)
1152
1153        schedule_log.debug("Schedule:\n %s", node_schedule)
1154
1155        return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel)
1156
1157    @staticmethod
1158    def reduction_hint(node):
1159        assert node.is_reduction()
1160        if all(
1161            dep.is_contiguous()
1162            for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
1163        ):
1164            return ReductionHint.INNER
1165        else:
1166            return node.node.data.reduction_hint
1167
1168    @staticmethod
1169    def can_use_32bit_indexing(
1170        numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]]
1171    ) -> bool:
1172        int_max = torch.iinfo(torch.int32).max
1173        size_hint = V.graph.sizevars.size_hint
1174        has_hint = V.graph.sizevars.shape_env.has_hint
1175
1176        def within_32bit(e):
1177            # Allow for unhinted e as long as we can still statically prove
1178            # (e.g., via ValueRanges) that it is still in bounds
1179            if V.graph.sizevars.is_expr_static_and_true(e <= int_max):
1180                return True
1181            # Otherwise, the hint MUST exist and be in range
1182            return has_hint(e) and size_hint(e) <= int_max
1183
1184        if not within_32bit(numel):
1185            return False
1186
1187        # Any use of a MultiOutputLayout will create a buffer with a
1188        # Layout whose sizes are accounted for
1189        buf_sizes = [
1190            buf.get_layout().storage_size()
1191            for buf in buffers
1192            if not isinstance(buf.get_layout(), ir.MultiOutputLayout)
1193        ]
1194
1195        if not all(within_32bit(size) for size in buf_sizes):
1196            return False
1197
1198        # Only install guards for 32-bit indexing as there is no correctness
1199        # issue with using 64-bit for everything
1200        V.graph.sizevars.guard_leq(numel, int_max)  # type: ignore[arg-type]
1201        for size in buf_sizes:
1202            V.graph.sizevars.guard_leq(size, int_max)  # type: ignore[arg-type]
1203        return True
1204
1205    @classmethod
1206    def select_index_dtype(cls, node_schedule, numel, reduction_numel):
1207        # Gather all used buffer names
1208        buffer_names: OrderedSet[str] = OrderedSet()
1209        for node in node_schedule:
1210            if not isinstance(node, scheduler.BaseSchedulerNode):
1211                continue
1212
1213            buffer_names.update(node.get_buffer_names())
1214            buffer_names.update(node.used_buffer_names())
1215
1216        # Get buffer objects
1217
1218        def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]:
1219            buf = V.graph.get_buffer(name)
1220            if buf is None:
1221                raise RuntimeError(f"Failed to find buffer matching name {name}")
1222            return buf
1223
1224        buffers = [_get_buffer(name) for name in buffer_names]
1225
1226        # In theory we can separately check xnumel and rnumel are <= int_max
1227        # but some indexers do use the full linear index so we need to be
1228        # conservative here.
1229        total_numel = numel * reduction_numel
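        # Editor's illustration (assumption): e.g. numel = 2**21 and
        # reduction_numel = 2**11 give total_numel = 2**32, which exceeds the
        # int32 maximum, so int64 indexing is selected even though each factor
        # individually fits in 32 bits.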
1230
1231        if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers):
1232            return cls.int32_type
1233        return cls.int64_type
1234
1235    def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel):
1236        pointwise_nodes = list(
1237            filter(
1238                lambda n: n not in (EnableReduction, DisableReduction)
1239                and not n.is_reduction()
1240                and n.group[1][0] == numel * rnumel,
1241                node_schedule,
1242            )
1243        )
1244        for node in pointwise_nodes:
1245            # An index can be an integer when loading a random seed.
1246            if not all(
1247                not isinstance(dep, MemoryDep)
1248                or dep.is_contiguous()
1249                or isinstance(dep.index, (sympy.Integer, int))
1250                or dep.stride1_for_last_dim()
1251                for dep in itertools.chain(
1252                    node.read_writes.reads, node.read_writes.writes
1253                )
1254            ):
1255                return True
1256        return False
1257
1258    def get_kernel_args(self, node_schedule, numel, reduction_numel):
1259        reductions = list(
1260            filter(
1261                lambda n: n not in (EnableReduction, DisableReduction)
1262                and n.is_reduction(),
1263                node_schedule,
1264            )
1265        )
1266        if len(reductions) > 0:
1267            hints = [self.reduction_hint(n) for n in reductions]
1268            if hints.count(hints[0]) == len(hints):
1269                reduction_hint_val = hints[0]
1270            else:
1271                reduction_hint_val = ReductionHint.DEFAULT
1272
1273            if (
1274                reduction_hint_val == ReductionHint.INNER
1275                and self.has_non_contiguous_pw_in_reduction_kernel(
1276                    node_schedule, numel, reduction_numel
1277                )
1278            ):
1279                reduction_hint_val = ReductionHint.DEFAULT
1280        else:
1281            reduction_hint_val = ReductionHint.DEFAULT
1282
1283        mutations: OrderedSet[str] = OrderedSet()
1284        for node in node_schedule:
1285            if node in (DisableReduction, EnableReduction):
1286                continue
1287
1288            for buf in node.get_outputs():
1289                mutations.update(buf.get_mutations())
1290
1291        index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel)
1292
1293        return reduction_hint_val, mutations, index_dtype
1294
1295    def codegen_node_schedule(
1296        self, node_schedule, buf_accesses, numel, reduction_numel
1297    ):
1298        from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel
1299
1300        tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel)
1301        (
1302            reduction_hint_val,
1303            mutations,
1304            index_dtype,
1305        ) = self.get_kernel_args(node_schedule, numel, reduction_numel)
1306
1307        is_split_scan = any(
1308            isinstance(node, BaseSchedulerNode) and node.is_split_scan()
1309            for node in node_schedule
1310        )
1311        kernel_type: type = self.kernel_type
1312        if is_split_scan and issubclass(TritonSplitScanKernel, kernel_type):
1313            kernel_type = TritonSplitScanKernel
1314
1315        kernel_args = tiled_groups
1316        kernel_kwargs = dict(
1317            reduction_hint=reduction_hint_val,
1318            mutations=mutations,
1319            index_dtype=index_dtype,
1320        )
1321
1322        def _node_has_sort(node):
1323            if node in (EnableReduction, DisableReduction):
1324                return False
1325
1326            sort_nodes = node._body.root_block.graph.find_nodes(
1327                op="call_method", target="sort"
1328            )
1329            return bool(sort_nodes)
1330
1331        # ops.sort only works with persistent reduction, and is not bandwidth bound anyway
1332        # so taking the hit of non-coalesced loads is okay
1333        has_sort = any(_node_has_sort(node) for node in node_schedule)
1334        if has_sort:
1335            kernel_kwargs["override_persistent_reduction"] = True
1336
1337        kernel = kernel_type(
1338            *kernel_args,
1339            **kernel_kwargs,
1340        )
1341        kernel.buf_accesses = buf_accesses
1342
1343        kernel2: Optional[SIMDKernel] = None
1344        if kernel.persistent_reduction and config.triton.multi_kernel and not has_sort:
1345            kernel2 = self.kernel_type(
1346                *kernel_args,
1347                **kernel_kwargs,
1348                override_persistent_reduction=False,
1349            )
1350            self.codegen_node_schedule_with_kernel(node_schedule, kernel2)
1351            with V.set_kernel_handler(kernel2):
1352                src_code2 = kernel2.codegen_kernel()
1353            kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel)
1354            kernel2.kernel_name = kernel_name2
1355            kernel2.code_hash = code_hash(src_code2)
1356
1357            # Keep buffers needed by the non-persistent reduction so both
1358            # kernels have the same arguments
1359            kernel.must_keep_buffers = set(kernel2.must_keep_buffers)
1360
1361        self.codegen_node_schedule_with_kernel(node_schedule, kernel)
1362
1363        with V.set_kernel_handler(kernel):
1364            src_code = kernel.codegen_kernel()
1365
1366        kernel_name = self.define_kernel(src_code, node_schedule, kernel)
1367        log.debug("Generating kernel code with kernel_name: %s", kernel_name)
1368        kernel.kernel_name = kernel_name
1369        kernel.code_hash = code_hash(src_code)
1370
1371        final_kernel = MultiKernel([kernel, kernel2]) if kernel2 is not None else kernel
1372
1373        with V.set_kernel_handler(final_kernel):
1374            for node in node_schedule:
1375                if node not in (EnableReduction, DisableReduction):
1376                    node.mark_run()
1377
1378        self.codegen_comment(node_schedule)
1379        final_kernel.call_kernel(final_kernel.kernel_name)
1380
1381        if config.nan_asserts:
1382            final_kernel.codegen_nan_check()
1383        if config.warn_mix_layout:
1384            final_kernel.warn_mix_layout(kernel_name)
1385
1386        V.graph.removed_buffers |= final_kernel.removed_buffers
1387        V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove
1388
1389        if (
1390            V.graph.wrapper_code.supports_intermediate_hooks
1391            and config.generate_intermediate_hooks
1392        ):
1393            # Not every node in the schedule will actually be live on output;
1394            # only emit hooks for buffers that are still live outputs (dead buffers cannot be hooked).
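            # The emitted wrapper call looks roughly like (illustrative, hypothetical names):
            #   run_intermediate_hooks('add_1', buf0)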
1395            live_outs = kernel.args.live_output_buffers()
1396            for node in node_schedule:
1397                if not isinstance(node, scheduler.BaseSchedulerNode):
1398                    continue
1399                name = node.get_name()
1400                if name not in live_outs:
1401                    continue
1402                assert node.node is not None
1403                origin_node = node.node.get_origin_node()
1404                if origin_node is not None:
1405                    counters["inductor"]["intermediate_hooks"] += 1
1406                    V.graph.wrapper_code.writeline(
1407                        f"run_intermediate_hooks({origin_node.name!r}, {name})"
1408                    )
1409
1410        self.scheduler.free_buffers()
1411
1412    def codegen_node_schedule_with_kernel(self, node_schedule, kernel):
1413        def current_reduction_nodes(nodes):
1414            return itertools.takewhile(lambda n: n is not DisableReduction, nodes)
1415
1416        with kernel:
1417            stack = contextlib.ExitStack()
1418            kernel.set_last_usage(current_reduction_nodes(node_schedule))
1419            all_indexing = {}
1420
1421            # First pass to collect indexing and decide inplace updates
1422            for node in node_schedule:
1423                if node is DisableReduction:
1424                    stack.enter_context(kernel.disable_reduction())
1425                elif node is EnableReduction:
1426                    stack.close()
1427                else:
1428                    node.decide_inplace_update()
1429                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
1430                    all_indexing.update(
1431                        dict.fromkeys(
1432                            node._body.indexing_from_args(index_vars).values()
1433                        )
1434                    )
1435
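            # Hand every index expression collected above to the kernel before any code
            # is emitted so it can make kernel-wide indexing decisions; the exact
            # behavior depends on the concrete kernel subclass.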
1436            kernel.finalize_indexing(all_indexing.keys())
1437
1438            # Second pass to do codegen
1439            for i, node in enumerate(node_schedule):
1440                if node is DisableReduction:
1441                    stack.enter_context(kernel.disable_reduction())
1442                elif node is EnableReduction:
1443                    stack.close()
1444                    kernel.set_last_usage(current_reduction_nodes(node_schedule[i:]))
1445                else:
1446                    # TODO - use split ranges?
1447                    indexing_dtype_strength_reduction(node._body)
1448                    index_vars = kernel.split_and_set_ranges(node.get_ranges())
1449                    node.codegen(index_vars)
1450
1451    def codegen_template(
1452        self, template_node, epilogue_nodes, only_gen_src_code=False
1453    ) -> Optional[str]:
1454        """
1455        Codegen a Triton template.
1456
1457        If `only_gen_src_code` is True, the source code is returned instead of being codegen'd into the wrapper.
1458        """
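        # Note: epilogue nodes are codegen'd into the template's <STORE_OUTPUT>
        # subgraph below, fusing their pointwise ops into the template's output store.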
1459        _, (numel, rnumel) = template_node.group
1460        assert rnumel == 1
1461        kernel, render = template_node.node.make_kernel_render(template_node.node)
1462        with kernel:
1463            if not only_gen_src_code:
1464                for node in [template_node, *epilogue_nodes]:
1465                    node.mark_run()
1466            partial_code = render()
1467            with kernel.set_subgraph_body("<STORE_OUTPUT>"):
1468                for node in epilogue_nodes:
1469                    node.codegen(kernel.split_and_set_ranges(node.get_ranges()))
1470
1471        if not isinstance(partial_code, str):
1472            partial_code.finalize_hook("<DEF_KERNEL>")
1473            partial_code.finalize_hook("<ARGDEFS>", strict=False)
1474        # finalize must be called after adding epilogue above
1475        with V.set_kernel_handler(kernel):
1476            # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
1477            with kernel.set_subgraph_body("<STORE_OUTPUT>"):
1478                if isinstance(partial_code, str):
1479                    src_code = partial_code
1480                else:
1481                    partial_code.finalize_hook("<STORE_OUTPUT>")
1482                    src_code = partial_code.code
1483            node_schedule = [template_node, *epilogue_nodes]
1484
1485            if config.benchmark_kernel:
1486                num_gb = kernel.estimate_kernel_num_bytes() / 1e9
1487                grid_args = V.graph.sizevars.size_hints(kernel.call_sizes)
1488                assert kernel.meta is not None, "meta is None"
1489                grid = kernel.grid_fn(*grid_args, kernel.meta)
1490                src_code = (
1491                    f"{kernel.imports_for_benchmark_kernel()}\n"
1492                    f"{src_code}\n"
1493                    f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}"
1494                )
1495
1496            if only_gen_src_code:
1497                return src_code
1498
1499            kernel_name = self.define_kernel(src_code, node_schedule, kernel)
1500
1501        self.codegen_comment(node_schedule)
1502        kernel.call_kernel(kernel_name, template_node.node)
1503
1504        V.graph.removed_buffers |= kernel.removed_buffers
1505        V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
1506        self.scheduler.free_buffers()
1507        return None
1508
1509    def codegen_sync(self):
1510        V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
1511
1512    def generate_combo_kernel_code(
1513        self,
1514        subkernel_nodes: List[BaseSchedulerNode],
1515        custom_part_algorithm: bool,
1516        enable_autotune: bool,
1517        mixed_sizes: bool,
1518        only_gen_src_code: bool = False,
1519    ) -> List[Tuple[str, Any, Any]]:
1520        from .triton_combo_kernel import ComboKernel
1521
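        # A ComboKernel fuses several independent subkernels "horizontally" into a
        # single Triton launch (each subkernel keeps its own tiling); the partitioning
        # below decides which subkernels share a launch.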
1522        fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
1523        subkernel_map, node_schedule_map = {}, {}
1524        for pn, nodes in zip(subkernel_nodes, fused_node_lists):
1525            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
1526            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
1527            tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
1528            node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel
1529            (
1530                reduction_hint_val,
1531                mutations,
1532                index_dtype,
1533            ) = self.get_kernel_args(node_schedule, numel, rnumel)
1534            subkernel_map[pn] = ComboKernel.create_triton_kernel(
1535                *tiled_groups,
1536                reduction_hint=reduction_hint_val,
1537                mutations=mutations,
1538                index_dtype=index_dtype,
1539                optimize_mask=not mixed_sizes,
1540            )
1541
1542        partitions = ComboKernel.horizontal_partition(
1543            nodes=subkernel_nodes,
1544            triton_scheduling=self,
1545            custom_algorithm=custom_part_algorithm,
1546            kernel_map=subkernel_map,
1547            node_info_map=node_schedule_map,
1548        )
1549        log.debug(
1550            "ComboKernels: %d nodes partitioned into %s groups",
1551            len(subkernel_nodes),
1552            [len(p) for p in partitions],
1553        )
1554        kernel_code_list = []
1555        for node_group in partitions:
1556            fused_node_lists = [node.get_nodes() for node in node_group]
1557            kernel = ComboKernel(
1558                enable_autotune=enable_autotune,
1559                mixed_sizes=mixed_sizes,
1560            )
1561
1562            for pn, nodes in zip(node_group, fused_node_lists):
1563                if only_gen_src_code:
1564                    # Clear last_usage; this may cause more aggressive 'evict_last' hints, which should be fine.
1565                    for n in nodes:
1566                        n.last_usage = OrderedSet()
1567                self.codegen_node_schedule_with_kernel(
1568                    node_schedule_map[pn][0],
1569                    kernel.create_sub_kernel(subkernel_map[pn]),
1570                )
1571                subkernel = subkernel_map[pn]
1572                node_schedule = node_schedule_map[pn][0]
1573                if not only_gen_src_code:
1574                    with V.set_kernel_handler(subkernel):  # type: ignore[call-arg]
1575                        for node in node_schedule:
1576                            if node not in (EnableReduction, DisableReduction):
1577                                node.mark_run()
1578                V.graph.removed_buffers |= subkernel.removed_buffers
1579                V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
1580
1581            src_code = kernel.codegen_kernel()
1582            kernel_code_list.append((src_code, kernel, node_group))
1583        return kernel_code_list
1584
1585    def codegen_combo_kernel(self, combo_kernel_node):
1586        subkernel_nodes = combo_kernel_node.get_subkernel_nodes()
1587        custom_part_algorithm = combo_kernel_node.use_custom_partition_algo
1588        enable_autotune = combo_kernel_node.enable_autotune
1589        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or (
1590            config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm
1591        )
1592
1593        kernel_code_list = self.generate_combo_kernel_code(
1594            subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes
1595        )
1596
1597        for src_code, kernel, _ in kernel_code_list:
1598            kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel)
1599            self.codegen_comment([combo_kernel_node])
1600            log.debug("ComboKernels: generated kernel %s.", kernel_name)
1601            kernel.call_kernel(V.graph.wrapper_code, kernel_name)
1602
1603        self.scheduler.free_buffers()
1604
1605    @staticmethod
1606    @functools.lru_cache(32)
1607    def candidate_tilings(node):
1608        ranges, reduction_ranges = node.get_ranges()
1609        if len(ranges) <= 1:
1610            return ()
1611
1612        rw = node.pointwise_read_writes()
1613        assert len(rw.range_vars) == len(ranges)
1614
1615        # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads
1616        # that need to access the entire tensor; they don't contribute read indexing
1617        # information (and practically, they don't have dep.index so they can't be used
1618        # for stride_hints below).
1619        dep_sources = [rw.reads, rw.writes]
1620        assert all(
1621            isinstance(dep, (MemoryDep, StarDep))
1622            for dep in itertools.chain.from_iterable(dep_sources)
1623        )
1624        deps = [
1625            dep
1626            for dep in itertools.chain.from_iterable(dep_sources)
1627            if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep)
1628        ]
1629        write_names = {dep.name for dep in rw.writes}
1630
1631        tilings: List[CandidateTiling] = []
1632
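        # Illustrative example (hypothetical sizes): for ranges == (64, 128) and a
        # transposed read whose stride hints are (1, 64), split == 1, giving
        # tiled_groups == (64, 128) and score == 64 * 128; a contiguous dep with
        # strides (128, 1) has split == len(ranges) and is skipped.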
1633        for dep in deps:
1634            strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars)
1635            assert len(strides) == len(ranges)
1636            try:
1637                split = strides.index(1) + 1
1638                if split == len(ranges):
1639                    continue
1640                if all(s == 0 for s in strides[split:]):
1641                    # if this is a broadcasted tensor and all dimensions after split are broadcast,
1642                    # this is not a real split
1643                    continue
1644
1645            except ValueError:
1646                continue
1647            tiled_groups = (
1648                V.graph.sizevars.simplify(sympy_product(ranges[:split])),
1649                V.graph.sizevars.simplify(sympy_product(ranges[split:])),
1650            )
1651            # score by number of elements
1652            score = V.graph.sizevars.size_hint(
1653                sympy_product(
1654                    size for size, stride in zip(ranges, strides) if stride != 0
1655                )
1656            )
1657            if dep.name in write_names:
1658                # ngimel said contiguous writes are more important than contiguous reads
1659                score *= 2
1660            if CandidateTiling.is_good_size(tiled_groups[0]):
1661                score *= 2
1662            if CandidateTiling.is_good_size(tiled_groups[1]):
1663                score *= 2
1664
1665            if (
1666                V.graph.sizevars.size_hint(
1667                    score - sympy_product(itertools.chain(ranges, reduction_ranges))
1668                )
1669                >= 0
1670            ):
1671                tilings.append(CandidateTiling(tiled_groups, score, dep.name))
1672        return tilings
1673
1674    @classmethod
1675    def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
1676        """
1677        Heuristics to decide how to tile kernels.
1678        Currently, we tile based on stride-1 dimensions.
1679
1680        Returns:
1681            `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel`
1682
1683        """
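        # Illustrative example (hypothetical sizes): for a pointwise kernel over
        # numel == 64 * 128 where one dep is read with a transposed layout,
        # candidate_tilings can propose (64, 128), and the returned groups would then
        # be (64, 128, 1), i.e. a 2D tiling with reduction_numel == 1.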
1684        if reduction_numel != 1 or config.triton.max_tiles <= 1:
1685            # TODO(jansel): should we tile reductions?
1686            # do perf hint here if stride-1 dim is not being reduced
1687            if perf_hint_log.level <= logging.WARNING:
1688                for node in EnableReduction.filter(node_schedule):
1689                    if len(cls.candidate_tilings(node)) > 0:
1690                        perf_hint_log.info("reduction over non-contiguous dims")
1691                        break
1692            return (numel, reduction_numel)
1693
1694        seen_names: OrderedSet[str] = OrderedSet()
1695        candidate_tiles: Counter[Any] = collections.Counter()
1696        for node in EnableReduction.filter(node_schedule):
1697            for tiling in cls.candidate_tilings(node):
1698                if tiling.name in seen_names:
1699                    continue
1700                seen_names.add(tiling.name)
1701                candidate_tiles[tiling.tiling] += tiling.score
1702
1703        ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()]
1704
1705        if config.triton.max_tiles >= 3:
1706            # Consider adding a third dimension of tiling, but only
1707            # when a1 is a multiple of b1; otherwise, you have a lot
1708            # of stragglers, which are annoying to generate code for.
1709            #
1710            # NB: More than three max tiles is not enabled by default.
1711
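            # Illustrative example (hypothetical sizes): with ranked_tilings[0] == (8, 1024)
            # and some ranked_tilings[i] == (64, 128), a1 == 1024 is a statically known
            # multiple of b1 == 128, so the 3D tiling (8, 1024 // 128, 128) == (8, 8, 128)
            # is prepended to ranked_tilings.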
1712            # Add one 3D tiling choice
1713            for i in range(1, len(ranked_tilings)):
1714                a0, a1 = ranked_tilings[0]
1715                b0, b1 = ranked_tilings[i]
1716                if V.graph.sizevars.size_hint(a1 - b1) == 0:
1717                    continue
1718                if V.graph.sizevars.size_hint(a1 - b1) < 0:
1719                    # swap so a1 is bigger
1720                    a0, a1 = ranked_tilings[i]
1721                    b0, b1 = ranked_tilings[0]
1722                assert V.graph.sizevars.size_hint(a1 - b1) > 0
1723                if V.graph.sizevars.statically_known_multiple_of(a1, b1):
1724                    tiling = (a0, FloorDiv(a1, b1), b1)
1725                    ranked_tilings = [tiling] + ranked_tilings
1726                    break  # only 1 choice for now
1727
1728        if len(ranked_tilings) > 1:
1729            perf_hint_log.info("possibly bad tiling: %s", ranked_tilings)
1730
1731        # Optionally, prefer tiling into as many dimensions as possible.
1732        if config.triton.prefer_nd_tiling:
1733            # Get candidate tilings from the node ranges.
1734            node_ranges = [
1735                node.get_ranges()[0]
1736                for node in EnableReduction.filter(node_schedule)
1737                if isinstance(node, scheduler.SchedulerNode)
1738            ]
1739            new_tilings: OrderedSet[Tuple[sympy.Expr, ...]] = OrderedSet()
1740            for node_range in node_ranges:
1741                # Collapse leading dims, to fit in the maximum dimensionality.
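                # Illustrative example (hypothetical sizes): with node_range == (2, 3, 4, 5)
                # and config.triton.max_tiles == 2, num_leading_dims == 2, so the first
                # three dims collapse to 2 * 3 * 4 == 24 and the candidate tiling is (24, 5).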
1742                num_leading_dims = max(0, len(node_range) - config.triton.max_tiles)
1743                first_trailing_dim = num_leading_dims + 1
1744                collapsed_leading_dim = sympy_product(node_range[:first_trailing_dim])
1745                tiling = [collapsed_leading_dim] + list(node_range[first_trailing_dim:])
1746                new_tilings.add(tuple(tiling))
1747
1748            # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D.
1749            # Since this is a stable sort, ties are broken by schedule order.
1750            ranked_new_tilings = sorted(new_tilings, key=len, reverse=True)
1751            ranked_tilings = ranked_new_tilings + ranked_tilings
1752
1753        for tiled_groups in ranked_tilings:
1754            new_groups = (*tiled_groups, reduction_numel)
1755            if all(
1756                SIMDKernel.is_compatible(new_groups, node.get_ranges())
1757                for node in node_schedule
1758                if isinstance(node, scheduler.SchedulerNode)
1759            ):
1760                return new_groups
1761
1762        return (numel, reduction_numel)
1763
1764    def flush(self):
1765        pass
1766
1767    def ready_to_flush(self) -> bool:
1768        return False
1769
1770    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
1771        @dataclasses.dataclass
1772        class LastUsageHolder:
1773            n: Any
1774            last_usage: Any
1775
1776            def __del__(self) -> None:
1777                self.n.last_usage = self.last_usage
1778
1779        last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes]
1780
1781        # Clear last_usage; this may cause more aggressive 'evict_last' hints, which should be fine.
1782        for n in nodes:
1783            n.last_usage = OrderedSet()
1784
1785        if not nodes[0].is_template():
1786            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
1787            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
1788
1789            tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
1790            reduction_hint_val, mutations, index_dtype = self.get_kernel_args(
1791                node_schedule, numel, rnumel
1792            )
1793
1794            kernel = self.kernel_type(
1795                *tiled_groups,
1796                reduction_hint=reduction_hint_val,
1797                mutations=mutations,
1798                index_dtype=index_dtype,
1799            )
1800
1801            self.codegen_node_schedule_with_kernel(node_schedule, kernel)
1802            with config.patch(
1803                "benchmark_kernel", benchmark_kernel
1804            ), V.set_kernel_handler(kernel):
1805                src_code = kernel.codegen_kernel()
1806        else:
1807            template_node = nodes[0]
1808            epilogue_nodes = nodes[1:]
1809
1810            with config.patch("benchmark_kernel", benchmark_kernel):
1811                src_code = self.codegen_template(
1812                    template_node, epilogue_nodes, only_gen_src_code=True
1813                )
1814
1815        src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
1816        return src_code
1817
1818    def codegen_comment(self, node_schedule):
1819        pass
1820
1821    def define_kernel(self, src_code, node_schedule, kernel):
1822        raise NotImplementedError
1823
1824
1825@dataclasses.dataclass
1826class CandidateTiling:
1827    tiling: Tuple[sympy.Expr, sympy.Expr]
1828    score: int  # higher is better
1829    name: Optional[str] = None
1830
1831    @staticmethod
1832    def is_good_size(s):
1833        """Somewhat arbitrary heuristic used to boost scores for some sizes"""
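        # For example (illustrative): 32, 64, and 256 count as good sizes (at least 32
        # and a multiple of 32), while 48 and 100 do not.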
1834        s = V.graph.sizevars.size_hint(s)
1835        return s >= 32 and (s % 32 == 0)
1836
1837
1838class DisableReduction:
1839    """
1840    Marker to invoke `kernel.disable_reduction()`.  This closes a
1841    reduction loop and allows for pointwise ops to occur on the output
1842    of a reduction.
1843    """
1844
1845
1846class EnableReduction:
1847    """
1848    Marker to end a DisableReduction block.
1849    """
1850
1851    @staticmethod
1852    def filter(node_schedule):
1853        """
1854        Get the nodes from node_schedule skipping those in a
1855        DisableReduction block.
1856        """
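        # For example (illustrative): for [a, DisableReduction, b, EnableReduction, c]
        # this yields a and c, skipping b.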
1857        disabled = False
1858        for node in node_schedule:
1859            if node in (EnableReduction, DisableReduction):
1860                # Don't tile stuff outside the main reduction loop
1861                disabled = node is DisableReduction
1862            elif disabled:
1863                pass
1864            else:
1865                yield node
1866
1867
1868class CantSplit(Exception):
1869    pass
1870