xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/triton.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import dataclasses
5import functools
6import itertools
7import logging
8import os
9import textwrap
10from functools import lru_cache
11from typing import (
12    Any,
13    Callable,
14    cast,
15    Dict,
16    Iterable,
17    List,
18    Optional,
19    Tuple,
20    TYPE_CHECKING,
21    Union,
22)
23
24import sympy
25
26import torch
27import torch._logging
28from torch._dynamo.utils import preserve_rng_state
29from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
30from torch._prims_common import is_integer_dtype
31from torch.utils._ordered_set import OrderedSet
32from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
33from torch.utils._triton import has_triton_package
34
35from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
36from ...utils._sympy.value_ranges import ValueRanges
37from .. import config, ir
38from ..codecache import code_hash, get_path, PyCodeCache
39from ..metrics import is_metric_table_enabled, log_kernel_metadata
40from ..runtime.benchmarking import benchmarker
41from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
42from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
43from ..utils import (
44    cache_on_self,
45    get_bounds_index_expr,
46    get_fused_kernel_name,
47    get_kernel_metadata,
48    is_welford_reduction,
49    Placeholder,
50    sympy_dot,
51    sympy_subs,
52)
53from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
54from ..wrapper_benchmark import get_kernel_category_by_source_code
55from .common import (
56    BackendFeature,
57    CSE,
58    CSEVariable,
59    DeferredLine,
60    IndentedBuffer,
61    OpOverrides,
62    PythonPrinter,
63    SizeArg,
64    TensorArg,
65    WorkspaceArg,
66)
67from .simd import (
68    constant_repr,
69    IterationRangesEntry,
70    IterationRangesRoot,
71    pexpr,
72    SIMDKernel,
73    SIMDScheduling,
74)
75from .triton_utils import (
76    config_of,
77    should_unwrap_unspec_arg,
78    signature_of,
79    signature_to_meta,
80)
81
82
83if TYPE_CHECKING:
84    from ..ir import IRNode
85
86log = logging.getLogger(__name__)
87perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
88schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
89fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
90
91
92@lru_cache(None)
93def gen_attr_descriptor_import():
94    """
95    Import AttrsDescriptor if the Triton version is new enough to have this
96    class defined.
97    """
98    if not has_triton_package():
99        return ""
100
101    import triton.compiler.compiler
102
103    if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
104        return "from triton.compiler.compiler import AttrsDescriptor"
105    else:
106        return ""
107
108
109@lru_cache(None)
110def gen_common_triton_imports():
111    imports = IndentedBuffer()
112    imports.splice(
113        """
114        import triton
115        import triton.language as tl
116        """
117    )
118    if attr_desc := gen_attr_descriptor_import():
119        imports.writeline(attr_desc)
120
121    imports.splice(
122        """
123        from torch._inductor.runtime import triton_helpers, triton_heuristics
124        from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
125        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
126        """
127    )
128    return imports.getvalue()
129
130
131block_offsets = {
132    symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
133    for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
134}
135
136block_sizes = {
137    symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True)
138    for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX]
139}
140
141
142@dataclasses.dataclass
143class IndexingOptions:
144    index_str: str
145    mask_vars: OrderedSet[str]
146    mask_str: str
147    expand_str: Optional[str]
148    _has_rindex: bool
149    index: sympy.Expr
150
151    def has_mask(self):
152        return bool(self.mask_vars)
153
154    def has_indirect(self):
155        return free_symbol_is_type(self.index, SymT.TMP)
156
157    def has_rindex(self):
158        return self._has_rindex
159
160    def has_tmpmask(self):
161        return "tmp" in self.mask_str
162
163    def has_rmask(self):
164        return "rmask" in self.mask_str
165
166
167@dataclasses.dataclass
168class BlockPtrOptions:
169    params: BlockParameters
170    constant_offset: sympy.Expr
171    order: List[int]
172    mask_vars: OrderedSet[str]
173    reshape_suffix: List[str]
174
175    @property
176    def shape(self) -> List[sympy.Expr]:
177        return self.params.shape
178
179    @property
180    def block_shape(self) -> List[sympy.Expr]:
181        return self.params.block_shape
182
183    @property
184    def strides(self) -> List[sympy.Expr]:
185        return self.params.strides
186
187    @property
188    def offsets(self) -> List[sympy.Expr]:
189        return self.params.offsets
190
191    @staticmethod
192    def create(
193        *,
194        params: BlockParameters,
195        constant_offset: sympy.Expr,
196        range_trees: List[IterationRangesEntry],
197        mask_vars: OrderedSet[str],
198    ) -> BlockPtrOptions:
199        """Helper to create a  BlockPtrOptions instance"""
200        reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees]
201
202        # Only drop broadcast dims if the output has the same
203        # rank as the block. Otherwise, we will get shape errors.
204        drop_broadcasts = len(reshape_suffix) == len(params.strides)
205
206        broadcasting_dim = [s == 0 for s in params.strides]
207        for i, is_broadcasting in enumerate(broadcasting_dim):
208            if is_broadcasting and drop_broadcasts:
209                # drop any stride==0 dimensions for performance
210                reshape_suffix[i] = "1"
211
212        if V.kernel.no_x_dim:
213            assert range_trees[0].prefix == "x"
214            reshape_suffix.pop(0)
215
216        if (
217            not V.kernel.inside_reduction
218            and len(params.strides) == len(V.kernel.numels) - 1
219            and V.kernel.numels[-1] != 1
220        ):
221            # Need to expand rank by 1 to match rank when self.inside_reduction=True
222            reshape_suffix.append("1")
223
224        def filter(it):
225            """Removes any broadcasting dims from a given sequence"""
226            assert len(it) == len(broadcasting_dim)
227            return [
228                item
229                for item, is_broadcasting in zip(it, broadcasting_dim)
230                if not is_broadcasting or not drop_broadcasts
231            ]
232
233        # Drop broadcasting dimensions from the input.
234        params = BlockParameters(
235            **{key: filter(val) for key, val in dataclasses.asdict(params).items()}
236        )
237
238        def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]:
239            return [V.graph.sizevars.lookup_precomputed_size(expr) for expr in exprs]
240
241        # Look up precomputed sizes
242        params.shape = lookup_size(params.shape)
243        params.strides = lookup_size(params.strides)
244
245        return BlockPtrOptions(
246            params=params,
247            constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
248            order=list(reversed(range(len(params.shape)))),
249            mask_vars=mask_vars,
250            reshape_suffix=reshape_suffix,
251        )
252
253    def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
254        """
255        Replaces instances of roffset with the new expression.
256        """
257        roffset = block_offsets[SymT.RINDEX]
258        return sympy_subs(expr, {roffset: replacement})
259
260    def format(self, name: str, roffset=True) -> str:
261        """
262        Codegen a call to tl.make_block_ptr()
263
264        Args:
265            name: variable name for pointer
266            roffset: should roffset be included in offsets=..., for use with tl.advance()
267
268        Returns:
269            "tl.make_block_ptr(...)"
270        """
271        f = V.kernel.index_to_str
272        offsets = [*self.offsets]
273        if not roffset:
274            offsets = [
275                self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets
276            ]
277        args = [
278            f"{name} + ({f(self.constant_offset)})"
279            if self.constant_offset != 0
280            else name,
281            f"shape={f(self.shape)}",
282            f"strides={f(self.strides)}",
283            f"block_shape={f(self.block_shape)}",
284            f"order={f(self.order)}",
285            f"offsets={f(offsets)}",
286        ]
287        return f"tl.make_block_ptr({', '.join(args)})"
288
289    @cache_on_self
290    def boundary_check(self) -> List[int]:
291        """List of indices to pass to tl.load(boundary_check=...)"""
292        sizevars = V.graph.sizevars
293
294        # Substitute maximum block sizes in shape expressions.
295        # This works in multiple_of checks because block sizes are powers of 2.
296        block_to_max: Dict[sympy.Expr, Any] = {
297            block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()]
298            for symt, block_size in block_sizes.items()
299        }
300
301        return [
302            idx
303            for idx in range(len(self.shape))
304            if (
305                not sizevars.statically_known_equals(
306                    self.strides[idx], sympy.Integer(0)
307                )
308                and not sizevars.statically_known_multiple_of(
309                    self.shape[idx], self.block_shape[idx]
310                )
311                and not sizevars.statically_known_multiple_of(
312                    self.shape[idx], sympy_subs(self.block_shape[idx], block_to_max)
313                )
314                and not (
315                    V.kernel.no_x_dim
316                    and self.block_shape[idx] == block_sizes[SymT.XBLOCK]
317                )
318            )
319        ]
320
321    def advance_roffset(self):
322        """
323        Codegen string to pass to tl.advance(name, ...).
324
325        Advance is the difference between offsets in each loop iteration.
326        To compute it, we replace roffset with multiples of RBLOCK.
327        Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first
328        iteration has roffset=0, while the second has roffset=RBLOCK.
329        """
330        rblock = block_sizes[SymT.RINDEX]
331        advance = [
332            (
333                self.replace_roffset(offset, rblock)
334                - self.replace_roffset(offset, sympy.Integer(0))
335            )
336            for offset in self.offsets
337        ]
338        return V.kernel.index_to_str(advance)
339
340    def has_indirect(self):
341        return False  # block_ptr can't do indirect indexing
342
343    def has_rindex(self) -> bool:
344        return any(free_symbol_is_type(expr, SymT.RINDEX) for expr in self.block_shape)
345
346    def has_rmask(self):
347        return self.has_rindex()
348
349    def has_tmpmask(self):
350        return False  # block_ptr can't do indirect indexing
351
352    def has_mask(self):
353        return bool(self.boundary_check())
354
355
356def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
357    """Workaround https://github.com/openai/triton/issues/2836"""
358    assert isinstance(old_shape, list) and isinstance(new_shape, list)
359    if old_shape == new_shape:
360        return value
361    if [s for s in new_shape if s != "1"] != old_shape:
362        return f"tl.reshape({value}, [{', '.join(new_shape)}])"
363    # rewrite to [:, None] syntax, which is less buggy
364    idx = 0
365    expand = []
366    for size in new_shape:
367        if idx < len(old_shape) and size == old_shape[idx]:
368            expand.append(":")
369            idx += 1
370        else:
371            assert size == "1"
372            expand.append("None")
373    assert idx == len(old_shape)
374    return f"{value}[{', '.join(expand)}]"
375
376
377# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
378# number of operators which Triton "implements", but in a way that is
379# inconsistent with Python semantics (and consistent with C semantics).  We
380# must override all of these, or there is a potential silent correctness problem.
381class TritonPrinter(PythonPrinter):
382    def _print_TruncToInt(self, expr):
383        assert len(expr.args) == 1
384        return (
385            f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
386        )
387
388    def _print_ToFloat(self, expr):
389        assert len(expr.args) == 1
390        return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"
391
392    def _print_PythonMod(self, expr):
393        quot, div = expr.args
394        quot_s = self._print(quot)
395        div_s = self._print(div)
396        if quot.is_nonnegative and div.is_nonnegative:
397            return f"{self.paren(quot_s)} % {self.paren(div_s)}"
398        return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
399
400    def _print_FloorDiv(self, expr):
401        assert expr.is_integer
402        quot, div = expr.args
403        quot_s = self._print(quot)
404        div_s = self._print(div)
405        if quot.is_nonnegative and div.is_nonnegative:
406            return f"({self.paren(quot_s)} // {self.paren(div_s)})"
407        return f"triton_helpers.div_floor_integer({quot_s},  {div_s})"
408
409    # TODO: This is wrong: when lhs, rhs > 2**53, Python uses a higher
410    # precision algorithm, which we would need to replicate here
411    def _print_IntTrueDiv(self, expr):
412        lhs, rhs = expr.args
413        return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
414
415    # NB: sympy.floor/ceiling produce integers, so we have to do the
416    # conversion to index dtype
417    def _print_floor(self, expr):
418        assert len(expr.args) == 1
419        return (
420            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
421        )
422
423    def _print_FloorToInt(self, expr):
424        assert len(expr.args) == 1
425        return (
426            f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
427        )
428
429    def _print_ceiling(self, expr):
430        assert len(expr.args) == 1
431        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
432
433    def _print_CeilToInt(self, expr):
434        assert len(expr.args) == 1
435        return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
436
437    def _helper_sqrt(self, expr):
438        return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))"
439
440    def _print_FloatPow(self, expr):
441        return (
442            f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
443        )
444
445    _print_PowByNatural = _print_FloatPow
446
447    def _print_Where(self, expr):
448        c = self.doprint(expr.args[0])
449        p = self.doprint(expr.args[1])
450        q = self.doprint(expr.args[2])
451        return f"tl.where({c}, {p}, {q})"
452
453    def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str:
454        """
455        Helper for max/min code generation.
456        cmp: > or <
457        """
458        nargs = len(expr.args)
459        if nargs == 1:
460            return self._print(expr.args[0])
461
462        mid = len(expr.args) // 2
463        cls = type(expr)
464        a = self._print(cls(*expr.args[:mid]))
465        b = self._print(cls(*expr.args[mid:]))
466
467        # Use a macro so we can propagate constexprs.
468        # https://github.com/triton-lang/triton/issues/3815
469        a, b = tuple(f"({x})" for x in (a, b))
470        assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
471        return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"
472
473    def _print_Min(self, expr):
474        return self._print_min_max_helper(expr, "<")
475
476    def _print_Max(self, expr):
477        return self._print_min_max_helper(expr, ">")
478
479    def _print_Abs(self, expr):
480        assert len(expr.args) == 1
481        return f"tl_math.abs({self._print(expr.args[0])})"
482
483    def _print_OpaqueUnaryFn_cos(self, expr):
484        assert len(expr.args) == 1
485        return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))"
486
487    def _print_OpaqueUnaryFn_cosh(self, expr):
488        assert len(expr.args) == 1
489        return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))"
490
491    def _print_OpaqueUnaryFn_acos(self, expr):
492        assert len(expr.args) == 1
493        return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))"
494
495    def _print_OpaqueUnaryFn_sin(self, expr):
496        assert len(expr.args) == 1
497        return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))"
498
499    def _print_OpaqueUnaryFn_sinh(self, expr):
500        assert len(expr.args) == 1
501        return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))"
502
503    def _print_OpaqueUnaryFn_asin(self, expr):
504        assert len(expr.args) == 1
505        return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))"
506
507    def _print_OpaqueUnaryFn_tan(self, expr):
508        assert len(expr.args) == 1
509        return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))"
510
511    def _print_OpaqueUnaryFn_tanh(self, expr):
512        assert len(expr.args) == 1
513        return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))"
514
515    def _print_OpaqueUnaryFn_atan(self, expr):
516        assert len(expr.args) == 1
517        return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"
518
519    def _print_RoundToInt(self, expr):
520        assert len(expr.args) == 1
521        return f"libdevice.llrint({self._print(expr.args[0])})"
522
523    def _print_RoundDecimal(self, expr):
524        assert len(expr.args) == 2
525        number, ndigits = expr.args
526        if number.is_integer:
527            # ndigits < 0 should have been filtered by the sympy function
528            assert ndigits < 0
529            raise ValueError(
530                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
531            )
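        # Illustrative sketch: round(x, 2) lowers to roughly
        #   libdevice.nearbyint(1e2 * x) * 1e-2
        # (scale by 10**ndigits, round to nearest, then scale back).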
532        return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"
533
534
535texpr = TritonPrinter().doprint
536
537
538def triton_compute_type(dtype):
539    triton_type_name = str(dtype).split(".")[-1]
540    if triton_type_name == "bool":
541        triton_type_name = "int1"
542    elif (
543        triton_type_name in ("float16", "bfloat16")
544        and config.triton.codegen_upcast_to_fp32
545    ):
546        # float16 math is done in float32 inside the kernel
547        triton_type_name = "float32"
548    elif triton_type_name == "float8_e4m3fn":
549        triton_type_name = "float8e4nv"
550    elif triton_type_name == "float8_e5m2":
551        triton_type_name = "float8e5"
552    elif triton_type_name == "float8_e4m3fnuz":
553        triton_type_name = "float8e4b8"
554    elif triton_type_name == "float8_e5m2fnuz":
555        triton_type_name = "float8e5b16"
556    return f"tl.{triton_type_name}"
557
558
559def _get_primitive_bitwidth(dtype):
560    if hasattr(dtype, "is_floating_point"):
561        if dtype.is_floating_point:
562            # triton_compute_type changes the bitwidth
563            if (
564                dtype in [torch.bfloat16, torch.float16]
565                and config.triton.codegen_upcast_to_fp32
566            ):
567                return 32
568            return torch.finfo(dtype).bits
569        else:
570            return torch.iinfo(dtype).bits
571    else:
572        return -1
573
574
575def triton_store_type(dtype):
576    triton_type_name = str(dtype).split(".")[-1]
577    if triton_type_name == "bool":
578        triton_type_name = "int8"
579    elif triton_type_name == "float8_e4m3fn":
580        triton_type_name = "float8e4nv"
581    elif triton_type_name == "float8_e5m2":
582        triton_type_name = "float8e5"
583    return f"tl.{triton_type_name}"
584
585
586def triton_acc_type(dtype):
587    if is_integer_dtype(dtype) and dtype.is_signed:
588        nbits = 64 if dtype == torch.int64 else 32
589        return f"tl.int{nbits}"
590    return triton_compute_type(dtype)
591
592
593class TritonCSEVariable(CSEVariable):
594    def __init__(self, name, bounds: ValueRanges[Any]) -> None:
595        super().__init__(name, bounds)
596        # We'll use this to track which masks the variable needs when used for indirect indexing
597        self.mask_vars: OrderedSet[str] = OrderedSet()
598
599    def update_on_args(self, name, args, kwargs):
600        for arg in args:
601            if isinstance(arg, TritonCSEVariable):
602                self.mask_vars.update(arg.mask_vars)
603            elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr":
604                # Most of the time, index vars don't need masks associated with them.
605                # However, when index vars are used to compute indices for indirect reads,
606                # those reads should subsequently be masked.
607                self.mask_vars.update({f"{arg.name[0]}mask"})
608
609
610class TritonOverrides(OpOverrides):
611    """Map element-wise ops to Triton"""
612
613    @staticmethod
614    def to_dtype(
615        x,
616        dtype: torch.dtype,
617        src_dtype: Optional[torch.dtype] = None,
618        use_compute_types=True,
619    ):
620        def _get_min_elements_per_thread(
621            src_dtype: torch.dtype, dst_dtype: torch.dtype
622        ) -> int:
623            if src_dtype == dst_dtype:
624                # No data type conversion is needed. No requirements on min_elem_per_thread.
625                return 0
626
627            # fp8 data type conversions have min_elem_per_thread requirements.
628            # Refer to Triton implementations here:
629            # https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.
630            fp8_dtypes = (
631                torch.float8_e4m3fn,
632                torch.float8_e5m2,
633            )
634            # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2.
635            assert not (
636                src_dtype in fp8_dtypes
637                and dst_dtype in fp8_dtypes
638                and src_dtype != dst_dtype
639            ), "Conversions between float8_e5m2 and float8_e4m3fn is not supported!"
640            if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2:
641                return 4
642            if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn:
643                return 2
644            # No requirements on min_elem_per_thread.
645            return 0
646
647        if src_dtype is not None:
648            # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype).
649            # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions
650            # in the same kernel.
651            V.kernel.min_elem_per_thread = max(
652                _get_min_elements_per_thread(src_dtype, dtype),
653                V.kernel.min_elem_per_thread,
654            )
655
656        if dtype == torch.bool:
657            return f"({x} != 0)"
658        elif dtype == torch.uint8:
659            # to work around llvm uint conversion semantics
660            # that produces 0's for negative values
661            return f"{x}.to(tl.int8).to(tl.uint8)"
662
663        if use_compute_types:
664            out_dtype = triton_compute_type(dtype)
665        else:
666            out_dtype = triton_store_type(dtype)
667
668        return f"{x}.to({out_dtype})"
669
670    @staticmethod
671    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
672        triton_dtype = triton_compute_type(dtype)
673        # We may promote float16 or bfloat16 to float32 and cause the
674        # bitwidth of dtype to be different from the input tensor (i.e. float32).
675        # In such a case, we will have to convert the input tensor to
676        # its src_type, perform bitcast, and then convert the bit-casted
677        # tensor back to float to ensure we use values with the right precision.
678        if (
679            src_dtype in (torch.float16, torch.bfloat16)
680            and config.triton.codegen_upcast_to_fp32
681        ):
682            triton_src_dtype = str(src_dtype).split(".")[-1]
683            cast_x = f"{x}.to(tl.{triton_src_dtype})"
684            if dtype in (torch.float16, torch.bfloat16):
685                triton_type_name = str(dtype).split(".")[-1]
686                triton_dtype = f"tl.{triton_type_name}"
687            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
688            return f"{cast_x}.to(tl.float32)"
689        else:
690            src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype)
691            target_dtype_bitwidth = _get_primitive_bitwidth(dtype)
692            bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False"
693            return f"{x}.to({triton_dtype}, bitcast={bitcast})"
694
695    @staticmethod
696    def _shaped_constant(value, dtype, shape):
697        type_ = torch._prims_common.dtype_to_type(dtype)
698        triton_val = constant_repr(type_(value))
699        triton_type = triton_compute_type(dtype)
700
701        if triton_type == "tl.float32":
702            # Float constants are always f32 in triton
703            return triton_val
704
705        # NOTE: We use a tensor here in order to get the expected type.
706        # Otherwise, e.g. float64 constants would be truncated to float32.
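        # Illustrative sketch: _shaped_constant(2.0, torch.float32, shape=[]) is just "2.0",
        # while _shaped_constant(1, torch.int64, shape=[]) is "tl.full([], 1, tl.int64)".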
707        return f"tl.full({shape}, {triton_val}, {triton_type})"
708
709    @classmethod
710    def constant(cls, value, dtype):
711        return cls._shaped_constant(value, dtype, shape=[])
712
713    @staticmethod
714    def abs(x):
715        return f"tl_math.abs({x})"
716
717    @staticmethod
718    def libdevice_abs(x):
719        return f"libdevice.abs({x})"
720
721    @staticmethod
722    def exp(x):
723        return f"tl_math.exp({x})"
724
725    @staticmethod
726    def libdevice_exp(x):
727        return f"libdevice.exp({x})"
728
729    @staticmethod
730    def exp2(x):
731        return f"libdevice.exp2({x})"
732
733    @staticmethod
734    def expm1(x):
735        return f"libdevice.expm1({x})"
736
737    @staticmethod
738    def sqrt(x):
739        return f"libdevice.sqrt({x})"
740
741    @staticmethod
742    def libdevice_sqrt(x):
743        return f"libdevice.sqrt({x})"
744
745    @staticmethod
746    def relu(x):
747        bug = config.triton.inject_relu_bug_TESTING_ONLY
748        if bug == "compile_error":
749            return "compile error!"
750        elif bug == "runtime_error":
751            # NB: this only triggers a runtime error as long as the input
752            # is not all zero
753            return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})'
754        elif bug == "accuracy":
755            return f"{x} + 1"
756        elif bug is None:
757            return ops.maximum(ops.constant(0, torch.int32), x)
758        else:
759            raise AssertionError(
760                f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}"
761            )
762
763    @staticmethod
764    def minimum(a, b):
765        return f"triton_helpers.minimum({a}, {b})"
766
767    @staticmethod
768    def maximum(a, b):
769        return f"triton_helpers.maximum({a}, {b})"
770
771    @staticmethod
772    def where(a, b, c):
773        return f"tl.where({a}, {b}, {c})"
774
775    @staticmethod
776    def inline_asm_elementwise(
777        *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1
778    ):
779        triton_type = triton_compute_type(dtype)
780        input_refs = ", ".join([str(i) for i in inputs])
781        if constraints is None:
782            constraints = ", ".join(["=r"] + ["r" for _ in inputs])
783        return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})"  # noqa: B950
784
785    @staticmethod
786    def cos(x):
787        return f"tl_math.cos({x})"
788
789    @staticmethod
790    def libdevice_cos(x):
791        return f"libdevice.cos({x})"
792
793    @staticmethod
794    def sin(x):
795        return f"tl_math.sin({x})"
796
797    @staticmethod
798    def libdevice_sin(x):
799        return f"libdevice.sin({x})"
800
801    @classmethod
802    def index_expr(cls, expr, dtype):
803        raise NotImplementedError("ops.index_expr not implemented outside a kernel")
804
805    @staticmethod
806    def masked(mask, body, other):
807        raise NotImplementedError("ops.masked not implemented outside a kernel")
808
809    @staticmethod
810    def lgamma(x):
811        return f"libdevice.lgamma({x})"
812
813    @staticmethod
814    def erf(x):
815        return f"libdevice.erf({x})"
816
817    @staticmethod
818    def cosh(x):
819        return f"libdevice.cosh({x})"
820
821    @staticmethod
822    def sinh(x):
823        return f"libdevice.sinh({x})"
824
825    @staticmethod
826    def acos(x):
827        return f"libdevice.acos({x})"
828
829    @staticmethod
830    def acosh(x):
831        return f"libdevice.acosh({x})"
832
833    @staticmethod
834    def asin(x):
835        return f"libdevice.asin({x})"
836
837    @staticmethod
838    def asinh(x):
839        return f"libdevice.asinh({x})"
840
841    @staticmethod
842    def atan2(x, y):
843        return f"libdevice.atan2({x}, {y})"
844
845    @staticmethod
846    def atan(x):
847        return f"libdevice.atan({x})"
848
849    @staticmethod
850    def atanh(x):
851        return f"libdevice.atanh({x})"
852
853    @staticmethod
854    def copysign(x, y):
855        return f"libdevice.copysign({x}, {y})"
856
857    @staticmethod
858    def erfc(x):
859        return f"libdevice.erfc({x})"
860
861    @staticmethod
862    def erfinv(x):
863        return f"libdevice.erfinv({x})"
864
865    @staticmethod
866    def hypot(x, y):
867        return f"libdevice.hypot({x}, {y})"
868
869    @staticmethod
870    def log10(x):
871        return f"libdevice.log10({x})"
872
873    @staticmethod
874    def log2(x):
875        return f"libdevice.log2({x})"
876
877    @staticmethod
878    def nextafter(x, y):
879        return f"libdevice.nextafter({x}, {y})"
880
881    @staticmethod
882    def logical_and(a, b):
883        return f"{a} & {b}"
884
885    @staticmethod
886    def logical_not(a):
887        return f"{a} == 0"
888
889    @staticmethod
890    def logical_or(a, b):
891        return f"{a} | {b}"
892
893    @staticmethod
894    def logical_xor(a, b):
895        return f"({a} ^ {b})"
896
897    @staticmethod
898    def bitwise_and(a, b):
899        return f"{a} & {b}"
900
901    @staticmethod
902    def bitwise_not(a):
903        return f"~{a}"
904
905    @staticmethod
906    def bitwise_or(a, b):
907        return f"{a} | {b}"
908
909    @staticmethod
910    def bitwise_xor(a, b):
911        return f"{a} ^ {b}"
912
913    @staticmethod
914    def bitwise_left_shift(a, b):
915        return f"{a} << {b}"
916
917    @staticmethod
918    def bitwise_right_shift(a, b):
919        return f"{a} >> {b}"
920
921    @staticmethod
922    def rand(seed, offset):
923        offset = f"({offset}).to(tl.uint32)"
924        return f"tl.rand({seed}, {offset})"
925
926    @staticmethod
927    def randn(seed, offset):
928        offset = f"({offset}).to(tl.uint32)"
929        return f"tl.randn({seed}, {offset})"
930
931    @staticmethod
932    def randint64(seed, offset, low, high):
933        offset = f"({offset}).to(tl.uint32)"
934        return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})"
935
936    @staticmethod
937    def load_seed(name, offset):
938        raise NotImplementedError("ops.load_seed not implemented outside a kernel")
939
940    @staticmethod
941    def rsqrt(x):
942        return f"libdevice.rsqrt({x})"
943
944    @staticmethod
945    def log1p(x):
946        return f"libdevice.log1p({x})"
947
948    @staticmethod
949    def tan(x):
950        return f"libdevice.tan({x})"
951
952    @staticmethod
953    def tanh(x):
954        return f"libdevice.tanh({x})"
955
956    @staticmethod
957    def sigmoid(x):
958        return f"tl.sigmoid({x})"
959
960    @staticmethod
961    def signbit(x):
962        # XX: This is wrong for the value -0.0 in floating point
963        return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"
964
965    @staticmethod
966    def fmod(a, b):
967        return f"libdevice.fmod({a}, {b})"
968
969    @staticmethod
970    def pow(a, b):
971        return f"libdevice.pow({a}, {b})"
972
973    @staticmethod
974    def log(x):
975        return f"tl_math.log({x})"
976
977    @staticmethod
978    def libdevice_log(x):
979        return f"libdevice.log({x})"
980
981    @staticmethod
982    def isinf(x):
983        return f"libdevice.isinf({x}).to(tl.int1)"
984
985    @staticmethod
986    def isnan(x):
987        return f"libdevice.isnan({x}).to(tl.int1)"
988
989    @staticmethod
990    def round(x):
991        return f"libdevice.nearbyint({x})"
992
993    @staticmethod
994    def floor(x):
995        return f"libdevice.floor({x})"
996
997    @staticmethod
998    def floordiv(a, b):
999        # See the comment in lowering.div_mode. a and b are integer type.
1000        # Similar to div_floor_kernel_cuda in pytorch core.
1001        # Notice that // in triton behaves as truncdiv instead of floordiv
1002        quot = f"{a} // {b}"
1003        rem = f"{a} % {b}"
1004        return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"
1005
1006    @staticmethod
1007    def sign(x):
1008        z = ops.constant(0, torch.int32)
1009        left = ops.to_dtype((ops.lt(z, x)), torch.int8)
1010        right = ops.to_dtype((ops.lt(x, z)), torch.int8)
1011        sub = ops.sub(left, right)
1012        return f"{sub}.to({x}.dtype)"
1013
1014    @staticmethod
1015    def trunc(x):
1016        return f"libdevice.trunc({x})"
1017
1018    @staticmethod
1019    def truncdiv(a, b):
1020        # See the comment in lowering.div_mode. a and b are integer type.
1021        # Notice that // in triton behaves as truncdiv instead of floordiv
1022        return f"{a} // {b}"
1023
1024    @staticmethod
1025    def ceil(x):
1026        return f"libdevice.ceil({x})"
1027
1028
1029TritonOverrides._initialize_pointwise_overrides("triton")
1030
1031
1032# Use mypy to check that the protocol is implemented correctly
1033def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
1034    return h
1035
1036
1037class TritonKernelOverrides(TritonOverrides):
1038    """Map element-wise ops to Triton within a TritonKernel
1039
1040    Unlike TritonOverrides, these assume the code is going to be inserted into
1041    the body of the main triton kernel and so it may use indexing and mask
1042    variables which are assumed to already be defined in the current scope.
1043    """
1044
1045    @classmethod
1046    def constant(cls, value, dtype):
1047        # NOTE: Cannot use shape=[] as it's not supported by triton-rocm
1048        # We could use shape=[1] instead but starting with the correct
1049        # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR.
1050        ndim = V.kernel.triton_tensor_ndim()
1051        shape = [1] * ndim
1052        return cls._shaped_constant(value, dtype, shape=shape)
1053
1054    @classmethod
1055    def index_expr(cls, expr, dtype):
1056        indexing = V.kernel.indexing(expr, block_ptr=False)
1057        assert isinstance(indexing, IndexingOptions)
1058        var = V.kernel.cse.generate(
1059            V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
1060        )
1061
1062        if dtype not in (torch.int32, torch.int64):
1063            var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
1064        var.mask_vars = indexing.mask_vars
1065        return var
1066
1067    @staticmethod
1068    def masked(mask, body, other):
1069        if mask is not None and torch.version.hip is not None:
1070            mask = V.kernel.cse.generate(
1071                V.kernel.compute,
1072                f"{mask}.to(tl.int1)",
1073            )
1074
1075        nodes = body.graph.find_nodes(op="output")
1076        assert nodes, "graph for body does not contain an output"
1077
1078        need_where = False
1079        for node in nodes:
1080            for arg in node.args:
1081                if arg.target != "load" or should_unwrap_unspec_arg(arg.args[0]):
1082                    need_where = True
1083
1084        value = None if need_where else other
1085        with V.kernel.mask_loads(mask, value=value) as new_mask:
1086            result = body()
1087
1088        if need_where:
1089            # Remove once CSEVariables track the dtype
1090            if result.bounds.is_bool:
1091                other = bool(other)
1092            # Take dtype from result to prevent accidental promotion
1093            other = V.kernel.cse.generate(
1094                V.kernel.compute,
1095                f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)",
1096                bounds=ValueRanges.wrap(other),
1097            )
1098            ret = ops.where(new_mask, result, other)
1099        else:
1100            ret = result
1101
1102        ret.mask_vars.discard(new_mask)
1103        return ret
1104
1105    @staticmethod
1106    def load_seed(name, offset):
1107        var = V.kernel.args.input(name)
1108        return (
1109            f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})"
1110        )
1111
1112    @staticmethod
1113    def frexp(x):
1114        cache_key = f"frexp({x})"
1115        if cache_key in V.kernel.cse.cache:
1116            return V.kernel.cse.cache[cache_key]
1117
1118        mantissa = V.kernel.cse.newvar()
1119        exponent = V.kernel.cse.newvar()
1120        V.kernel.compute.writeline(
1121            f"{mantissa}, {exponent} = triton_helpers.frexp({x})"
1122        )
1123        V.kernel.cse.cache[cache_key] = (mantissa, exponent)
1124        return (mantissa, exponent)
1125
1126
1127# Use mypy to check that the protocol is implemented correctly
1128def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]:
1129    return h
1130
1131
1132class HelperFunctions:
1133    """An ordered set of helper functions."""
1134
1135    _templates_seen: Dict[str, str]  # Template code to function name
1136    finalized_helpers: List[str]
1137
1138    def __init__(self) -> None:
1139        self._templates_seen = {}
1140        self.finalized_helpers = []
1141
1142    def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str:
1143        """This accepts a function definition with the function name
1144        left as a format specifier e.g.
1145
1146            @triton.jit
1147            def {name}(arg0, arg1):
1148                return arg0 + arg1
1149
1150        We add the templated code to the function set and return the name
1151        assigned to that function.
1152
1153        """
1154        existing_name = self._templates_seen.get(template_code)
1155        if existing_name is not None:
1156            # Don't duplicate existing helpers
1157            return existing_name
1158
1159        name = f"{base_name}{len(self.finalized_helpers)}"
1160        self._templates_seen[template_code] = name
1161        self.finalized_helpers.append(template_code.format(name=name))
1162        return name
1163
1164    def __iter__(self):
1165        return iter(self.finalized_helpers)
1166
1167    def __getitem__(self, idx):
1168        return self.finalized_helpers[idx]
1169
1170
1171@dataclasses.dataclass
1172class BlockParameters:
1173    """
1174    Class representing ND block dimensions, for block pointer analysis.
1175    """
1176
1177    shape: List[sympy.Expr] = dataclasses.field(default_factory=list)
1178    block_shape: List[sympy.Expr] = dataclasses.field(default_factory=list)
1179    strides: List[sympy.Expr] = dataclasses.field(default_factory=list)
1180    offsets: List[sympy.Expr] = dataclasses.field(default_factory=list)
1181
1182    def __add__(self, other: BlockParameters) -> BlockParameters:
1183        """
1184        Concatenates block parameters.
1185        """
1186        cls = type(self)
1187        a, b = tuple(dataclasses.asdict(x) for x in (self, other))
1188        return cls(**{key: a[key] + b[key] for key in a})
1189
1190
1191class TritonKernel(SIMDKernel):
1192    overrides = TritonKernelOverrides  # type: ignore[assignment]
1193    helper_functions: HelperFunctions
1194    kexpr: Callable[[sympy.Expr], str] = texpr
1195    allow_block_ptr = True
1196
1197    def __init__(
1198        self,
1199        *groups,
1200        index_dtype: str,
1201        mutations: Optional[OrderedSet[str]] = None,
1202        pid_cache=None,
1203        reduction_hint=ReductionHint.DEFAULT,
1204        min_elem_per_thread=0,
1205        override_persistent_reduction=None,
1206        optimize_mask=True,
1207    ) -> None:
1208        self.optimize_mask: bool = optimize_mask
1209        super().__init__(
1210            *groups,
1211            index_dtype=index_dtype,
1212            mutations=mutations,
1213            reduction_hint=reduction_hint,
1214            pid_cache=pid_cache,
1215            override_persistent_reduction=override_persistent_reduction,
1216        )
1217        self.suffix: IndentedBuffer = IndentedBuffer()  # type: ignore[assignment]
1218        self.outside_loop_vars: OrderedSet[Any] = OrderedSet()
1219        self.min_elem_per_thread = min_elem_per_thread
1220        self.block_ptr_id = itertools.count()
1221        self.helper_functions = HelperFunctions()
1222
1223        # A set of autotuning hints to pass as part of triton_meta
1224        self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet()
1225        self.triton_meta: Optional[Dict[str, object]] = None
1226
1227        self.codegen_range_tree()
1228
1229    def _get_symt(self, tree: IterationRangesEntry) -> SymT:
1230        prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()}
1231        return prefix_to_symt[tree.prefix]
1232
1233    def _get_block_size(self, tree: IterationRangesEntry) -> sympy.Symbol:
1234        return block_sizes[self._get_symt(tree)]
1235
1236    def _get_block_offset(self, tree: IterationRangesEntry) -> sympy.Symbol:
1237        return block_offsets[self._get_symt(tree)]
1238
1239    def _max_block_size(self, tree: IterationRangesEntry) -> int:
1240        return TRITON_MAX_BLOCK[tree.prefix.upper()]
1241
1242    def codegen_range_tree(self):
1243        for tree in self.range_trees:
1244            # reduction indexing goes inside a loop
1245            if not tree.is_loop:
1246                self.iteration_ranges_codegen_header(tree, self.body)
1247        if self.inside_reduction and self.range_trees[-1].is_loop:
1248            # workaround for this issue:
1249            # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
1250            self.body.writeline(
1251                f"rbase = {self.iteration_ranges_ranges_code(self.range_trees[-1])}"
1252            )
1253
1254    def need_numel_args(self):
1255        r"""
1255        Indicate whether we need to provide numel as arguments for the generated
1257        kernel calls in the benchmark.
1258
1259        Should be true for pointwise/reduction kernels but false for triton
1260        matmul kernels.
1261        """
1262        return True
1263
1264    def should_use_persistent_reduction(self) -> bool:
1265        """
1266        Heuristic to set self.persistent_reduction and add guards
1267        if needed.
1268        """
1269        if not (self.inside_reduction and config.triton.persistent_reductions):
1270            return False
1271        threshold = {
1272            ReductionHint.INNER: 1024,
1273        }.get(self.reduction_hint, 64)
1274
1275        # If multi_kernel is enabled, we do more aggressive persistent reduction.
1276        # This may result in some persistent reductions being slower than the
1277        # corresponding non-persistent reductions. MultiKernel will do benchmarking
1278        # to pick the faster one.
1279        if config.triton.multi_kernel:
1280            threshold *= 16
1281        last_numel = self.numels[-1]
1282        return V.graph.sizevars.statically_known_leq(last_numel, threshold)  # type: ignore[arg-types]
1283
1284    def want_no_x_dim(self):
1285        return (
1286            self.reduction_hint == ReductionHint.INNER
1287            and self.persistent_reduction
1288            and len(self.numels) == 2
1289            and V.graph.sizevars.statically_known_geq(self.numels[-1], 256)  # type: ignore[arg-types]
1290        )
1291
1292    @property
1293    def assert_function(self) -> str:
1294        return "tl.device_assert"
1295
1296    def indexing(
1297        self,
1298        index: sympy.Expr,
1299        *,
1300        copy_shape=None,
1301        dense_indexing=False,
1302        override_mask=None,
1303        block_ptr=False,
1304    ):
1305        """
1306        Compute the index and mask to pass to tl.load() or tl.store()
1307        """
1308        index = self.prepare_indexing(index)
1309        index_vars = index.free_symbols
1310        has_rindex = False
1311
1312        mask_vars: OrderedSet[str] = OrderedSet()
1313        for var in index_vars:
1314            assert isinstance(var, sympy.Symbol)
1315            has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
1316            if override_mask:
1317                pass
1318            elif symbol_is_type(var, SymT.TMP):
1319                # indirect indexing
1320                cse_var = self.cse.varname_map[var.name]
1321                mask_vars.update(cse_var.mask_vars)
1322            elif symbol_is_type(
1323                var,
1324                (
1325                    SymT.UNBACKED_INT,
1326                    SymT.SIZE,
1327                    SymT.PRECOMPUTED_SIZE,
1328                    SymT.INDEX,
1329                    SymT.FLOAT,
1330                    SymT.UNBACKED_FLOAT,
1331                ),
1332            ):
1333                pass
1334            else:
1335                # var is one of xN, yN or rN
1336                assert symbol_is_type(
1337                    var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
1338                ), var.name
1339                mask_vars.add(f"{var.name[0]}mask")
1340
1341        need_dense = (
1342            config.triton.dense_indexing
1343            or dense_indexing
1344            or self._load_mask is not None
1345        ) and index != 0
1346
1347        have_dense = True
1348        have_loop_vars = False
1349        dense_mask_vars: OrderedSet[str] = OrderedSet()
1350
1351        for tree in self.active_range_trees():
1352            if index_vars.intersection(tree.var_list):
1353                have_loop_vars = True
1354            else:
1355                have_dense = False
1356            dense_mask_vars.add(f"{tree.prefix}mask")
1357
1358        if (
1359            block_ptr
1360            and self.allow_block_ptr
1361            and config.triton.use_block_ptr
1362            and not override_mask
1363            and not self._load_mask
1364            and len(mask_vars - dense_mask_vars) == 0
1365            and not self.is_indirect_indexing(index)
1366            and have_loop_vars
1367            # workaround https://github.com/openai/triton/issues/2821
1368            and self.index_dtype == "tl.int32"
1369        ):
1370
1371            def match_strided_block(
1372                index: sympy.Expr, range_tree: IterationRangesEntry
1373            ) -> Optional[BlockParameters]:
1374                """
1375                Matches expressions of the form:
1376                    idx = s * xindex
1377
1378                This implies stride (s,), and shape (XBLOCK,).
1379                """
1380                symbol = range_tree.symbol()
1381                stride = sympy.Wild("stride", exclude=[symbol])
1382                m = index.match(symbol * stride)
1383                if m is None:
1384                    return None
1385
1386                return BlockParameters(
1387                    shape=[range_tree.numel],
1388                    block_shape=[self._get_block_size(range_tree)],
1389                    strides=[m[stride]],
1390                    offsets=[self._get_block_offset(range_tree)],
1391                )
1392
1393            def match_mod_div_block(
1394                index: sympy.Expr, range_tree: IterationRangesEntry
1395            ) -> Optional[BlockParameters]:
1396                """
1397                Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing.
1398
1399                Example expression to match:
1400                   sN * ((rindex//(d1 * ... * d(N-1))))
1401                       + s1 * ModularIndexing(rindex, 1, d1)
1402                       + ...
1403                       + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1))
1404
1405                This iterates over a block of shape (dN, ..., d1) and stride
1406                (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are
1407                wildcards that we match.
1408
1409                Note that dN does not appear in the expression, but we solve for it
1410                using range tree numels and the other dims.
1411                """
1412                # Bound the possible number of dims. We use the following heuristics:
1413                # - At least one dim for each range tree node.
1414                # - At least one dim for every FloorDiv or ModularIndexing op.
1415                # - At least 2 dims to pattern match.
1416                num_dims = max(
1417                    2,
1418                    len(self.range_tree_nodes),
1419                    (index.count(FloorDiv) + index.count(ModularIndexing)),
1420                )
1421
1422                # Pattern match to find the strides and offset.
1423                index_var = range_tree.symbol()
1424                wild = functools.partial(sympy.Wild, exclude=[index_var])
1425                dims: List[sympy.Expr] = [
1426                    wild(f"dim_mod{idx}") for idx in range(num_dims)
1427                ]
1428                strides: List[sympy.Expr] = [
1429                    wild(f"stride_mod{idx}") for idx in range(num_dims)
1430                ]
1431
1432                def get_slice_numels(dims: List[Any]) -> List[Any]:
1433                    """
1434                    Compute the cumulative size of each dimension's slice.
1435                    This proceeds from the last dim up to the second.
1436                    """
1437                    numels = [sympy.Integer(1)]
1438                    for dim in dims[:0:-1]:
1439                        numel = dim * numels[0]
1440                        numels.insert(0, numel)
1441                    return numels
1442
1443                # The first dimension's index is computed by division.
1444                # The remaining are computed by modulo.
1445                slice_numels = get_slice_numels(dims[:num_dims])
1446                block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
1447                    ModularIndexing(index_var, numel, dim)
1448                    for dim, numel in zip(dims[1:], slice_numels[1:])
1449                ]
1450
1451                # Calculate a linear index from block indices.
1452                match_expr = sympy_dot(strides, block_index_exprs)
1453
1454                # Pattern match.
1455                match = index.match(match_expr)
1456                if match is None:
1457                    return None
1458
1459                # Provide default values for unmatched dims and strides.
1460                for dim in dims[1:]:
1461                    if dim not in match:
1462                        match[dim] = sympy.Integer(1)
1463                for stride in strides[1:]:
1464                    if stride not in match:
1465                        match[stride] = sympy.Integer(0)
1466
1467                sizevars = V.graph.sizevars
1468
1469                def get_match(expr: sympy.Expr) -> sympy.Expr:
1470                    return sizevars.lookup_precomputed_size(match[expr])
1471
1472                # Replace wildcards with matched expressions.
1473                dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
1474                strides = [get_match(stride) for stride in strides]
1475                slice_numels = get_slice_numels(dims)
1476                block_index_exprs = [
1477                    sympy_subs(expr, match) for expr in block_index_exprs
1478                ]
1479
1480                # The leading dimension is not directly matched in our expression.
1481                # We solve for it by dividing the range tree numel by the product of
1482                # all other dimensions. We quit if they are not known to be divisible.
1483                assert (
1484                    dims[0] not in match
1485                ), "Expected not to match the leading dimension!"
1486                if not sizevars.statically_known_multiple_of(
1487                    range_tree.numel, slice_numels[0]
1488                ):
1489                    return None
1490                dims[0] = range_tree.numel / slice_numels[0]
1491
1492                # Check for applicable iteration range sizes.
1493                # When mapping a 1D block into an ND one, we need to know that
1494                # the number of elements is not changed. This means the slice numels of
1495                # the ND iteration range must evenly divide the length of the 1D block.
1496                # There are two cases where we can guarantee this:
1497                #  1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m,
1498                #     with n and m integers, then either numel is a multiple of XBLOCK, or numel
1499                #     is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
1500                #  2. Numels are multiples of the maximum possible block size.
1501                max_block = self._max_block_size(range_tree)
1502                if any(
1503                    not sizevars.statically_known_multiple_of(numel, max_block)
1504                    and not sizevars.statically_known_power_of_2(numel)
1505                    for numel in slice_numels
1506                ):
1507                    return None
1508
1509                def identity(expr: sympy.Expr) -> sympy.Expr:
1510                    return expr
1511
1512                # Compute the ND block shape from the linear block size.
1513                # Use CeilDiv to round leading dimensions up to 1.
1514                # Non-leading dimensions are clamped to the size of the iteration range,
1515                # while the leading dimension can exceed this to accommodate a larger
1516                # block size.
1517                linear_block_size = self._get_block_size(range_tree)
1518                block_shape: List[sympy.Expr] = [
1519                    CeilDiv(linear_block_size, slice_numels[0])
1520                ] + [
1521                    sympy.Min(CeilDiv(linear_block_size, numel), dim)
1522                    for numel, dim in zip(slice_numels[1:], dims[1:])
1523                ]
1524
1525                # Compute block offsets from {xyzr}offset and the matched expressions.
1526                block_offsets: List[sympy.Expr] = [
1527                    sympy_subs(expr, {index_var: self._get_block_offset(range_tree)})
1528                    for expr in block_index_exprs
1529                ]
1530
1531                return BlockParameters(
1532                    shape=dims,
1533                    block_shape=block_shape,
1534                    strides=strides,
1535                    offsets=block_offsets,
1536                )
1537
1538            def match_block_pointer_subexpr(
1539                expr: sympy.Expr, range_tree: IterationRangesEntry
1540            ) -> Optional[BlockParameters]:
1541                """
1542                Match a block indexing subexpression involving a single range tree.
1543                """
1544                for match_func in (
1545                    match_strided_block,
1546                    match_mod_div_block,
1547                ):
1548                    match = match_func(expr, range_tree)
1549                    if match is not None:
1550                        return match
1551
1552                return None
1553
1554            def match_block_pointer() -> Optional[BlockPtrOptions]:
1555                index_relative_to_xyr_index = sympy_subs(
1556                    index, {v: t.expr for v, t in self.range_tree_nodes.items()}
1557                )
1558                range_trees = self.active_range_trees(reorder=True)
1559
1560                # Match each range tree separately.
1561                range_symbols = {tree.symbol() for tree in range_trees}
1562                index_terms = sympy.Add.make_args(index_relative_to_xyr_index)
1563                block_params = BlockParameters()
1564                for tree in range_trees:
1565                    # Partition the index into subexpressions pertaining to each range tree.
1566                    # For example xindex * 5 + rindex * 3 is partitioned to
1567                    # (xindex * 5, rindex * 3).
1568                    symbol = tree.symbol()
1569                    subexpr = sympy.Integer(0) + sum(
1570                        expr for expr in index_terms if symbol in expr.free_symbols
1571                    )
1572
1573                    # Reject mixed terms, e.g. xindex * rindex.
1574                    # NB: the zero expression is allowed, for broadcasting.
1575                    if len(range_symbols.intersection(subexpr.free_symbols)) > 1:
1576                        return None
1577
1578                    # Match the subexpression for this range tree.
1579                    params = match_block_pointer_subexpr(subexpr, tree)
1580                    if params is None:
1581                        return None
1582                    block_params += params
1583
1584                # Collect leftover terms as a constant offset.
1585                offset = sum(
1586                    expr
1587                    for expr in index_terms
1588                    if not range_symbols.intersection(expr.free_symbols)
1589                )
1590
1591                # Form the block pointer.
1592                self.filter_masks(mask_vars)
1593                return BlockPtrOptions.create(
1594                    params=block_params,
1595                    constant_offset=offset,
1596                    range_trees=range_trees,
1597                    mask_vars=mask_vars,
1598                )
1599
1600            # Return a block pointer, if indexing matches the pattern.
1601            options = match_block_pointer()
1602            if options is not None:
1603                return options
1604
1605        expand_str = None
1606        index_str = self.index_to_str(index)
1607        if isinstance(index, sympy.Integer):
1608            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
1609            index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
1610            return IndexingOptions(
1611                index_str, OrderedSet(), "None", expand_str, has_rindex, index
1612            )
1613
1614        if need_dense and not have_dense:
1615            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
1616            index_str = f"tl.broadcast_to({index_str}, {expand_str})"
1617            mask_vars = dense_mask_vars
1618        elif not have_loop_vars and copy_shape:
1619            index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)"
1620            mask_vars = dense_mask_vars
1621
1622        if override_mask:
1623            mask_vars = OrderedSet([override_mask])
1624
1625        if self._load_mask:
1626            mask_vars.add(self._load_mask)
1627
1628        self.filter_masks(mask_vars)
1629
1630        mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
1631        return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index)  # type: ignore[arg-type]
1632
1633    def codegen_block_ptr(
1634        self, name: str, var: str, indexing: BlockPtrOptions, other=""
1635    ) -> Tuple[str, Optional[DeferredLine], str]:
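        """
        Render a block pointer expression (via ``indexing.format``) for ``var``.
        When the pointer depends on the reduction index inside a reduction
        loop, its definition is written to ``self.body`` as a deferred line and
        a matching ``tl.advance`` line is returned so the caller can advance it
        each iteration.  ``other`` is rewritten to carry the
        ``boundary_check``/``padding_option`` arguments.  Returns
        ``(block_ptr, advance_block_ptr, other)``.
        """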
1636        advance_block_ptr = None
1637        check = indexing.boundary_check()
1638        if not check:
1639            # workaround https://github.com/openai/triton/issues/2813
1640            other = ""
1641        elif other:
1642            assert other == ", other=0.0"
1643            other = f", boundary_check={check!r}, padding_option='zero'"
1644        else:
1645            other = f", boundary_check={check!r}"
1646        if (
1647            self.inside_reduction
1648            and self.range_trees[-1].is_loop
1649            and indexing.has_rindex()
1650        ):
1651            block_ptr = f"block_ptr{next(self.block_ptr_id)}"
1652            self.body.writeline(
1653                DeferredLine(
1654                    name, f"{block_ptr} = {indexing.format(var, roffset=False)}"
1655                )
1656            )
1657            advance_block_ptr = DeferredLine(
1658                name,
1659                f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})",
1660            )
1661        else:
1662            block_ptr = indexing.format(var)
1663        return block_ptr, advance_block_ptr, other
1664
1665    def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
1666        # broadcasting is not implicit for block_ptrs
1667        value = (
1668            f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
1669        )
1670        # drop any extra size=1 dimensions
1671        block_shape = [V.kernel.index_to_str(expr) for expr in indexing.block_shape]
1672        value = triton_reshape(value, indexing.reshape_suffix, block_shape)
1673        # workaround https://github.com/openai/triton/issues/2814
1674        value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
1675        return f"tl.store({block_ptr}, {value}{other})"
1676
1677    def check_bounds(
1678        self,
1679        expr: sympy.Expr,
1680        size: sympy.Expr,
1681        lower: bool,
1682        upper: bool,
1683    ):
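        """
        Emit a bounds assertion (via ``self.indirect_assert``) checking that
        ``expr`` lies within ``[0, size)``, honoring the ``lower``/``upper``
        flags and any mask attached to the indexing of ``expr``.
        """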
1684        if not (lower or upper):
1685            return
1686
1687        assert isinstance(expr, sympy.Expr)
1688        indexing = self.indexing(expr, block_ptr=False)
1689        assert isinstance(indexing, IndexingOptions)
1690
1691        index_str = indexing.index_str
1692        mask_str = indexing.mask_str if indexing.has_mask() else None
1693        size_str = texpr(self.rename_indexing(size)) if upper else None
1694
1695        # expr is already wrapped
1696        line = self.indirect_assert(
1697            index_str, "0" if lower else None, size_str, mask_str
1698        )
1699
1700        indirect = self.is_indirect_indexing(expr) or any(
1701            isinstance(m, TritonCSEVariable) for m in indexing.mask_vars
1702        )
1703        buffer = self.get_load_buffer(indexing)
1704        self.cse.generate(buffer, line, assignment=False)
1705
1706    def get_load_buffer(self, indexing):
1707        if indexing.has_indirect() or indexing.has_tmpmask():
1708            # Masked loads must come after the mask is computed
1709            return self.compute
1710        elif (
1711            self.inside_reduction
1712            and self.range_trees[-1].is_loop
1713            and not indexing.has_rindex()
1714        ):
1715            # can lift a common load outside of the reduction loop
1716            # One exception is when this is an indirect_load.
1717            return self.body
1718        else:
1719            return self.loads
1720
1721    def load(self, name: str, index: sympy.Expr):
1722        var = self.args.input(name)
1723        indirect_indexing = self.is_indirect_indexing(index)
1724        original_index = index
1725        indexing = self.indexing(index, block_ptr=True)
1726        has_rindex = indexing.has_rindex()
1727        has_tmpmask = indexing.has_tmpmask()
1728
1729        # Keep the variable in cache if we're going to reuse it. Equiv., if any of the following hold
1730        #  1) We are doing broadcasting
1731        #  2) It is a non-coalesced load. The intuition is that if it's
1732        #  non-coalesced, we will likely load each element multiple times in
1733        #  practice.
1734        #  3) It will be used later and it won't be CSE'd. Equiv., if all the following hold
1735        #   3.1) We are in a reduction loop
1736        #   3.2) It's not its last use
1737        #   3.3) This load will not be lifted to the body
1738        #
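        # The chosen policy is spliced into the generated load; for example
        # (names illustrative only):
        #     tmp3 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')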
1739        is_coalesced = any(
1740            i == 1 for i in self.get_strides_of_load(original_index).values()
1741        )
1742        if self.is_broadcasted(original_index):
1743            ep = ", eviction_policy='evict_last'"
1744        elif not is_coalesced:
1745            ep = ", eviction_policy='evict_last'"
1746        elif self.inside_reduction and self.range_trees[-1].is_loop:
1747            if name in self.args.inplace_buffers:
1748                names: OrderedSet[str] = OrderedSet(
1749                    self.args.inplace_buffers[name].other_names
1750                )
1751            else:
1752                names = OrderedSet([name])
1753            last_use = len(names & self.last_usage) > 0
1754            evict_last = not last_use and (has_rindex or indirect_indexing)
1755            if evict_last:
1756                ep = ", eviction_policy='evict_last'"
1757            else:
1758                ep = ", eviction_policy='evict_first'"
1759        else:
1760            ep = ""
1761
1762        if (has_tmpmask or has_rindex) and indexing.has_mask():
1763            if self._load_other:
1764                other = f", other={constant_repr(self._load_other)}"
1765            else:
1766                other = ", other=0.0"
1767        else:
1768            other = ""
1769
1770        advance_block_ptr = None
1771        append_broadcast = None
1772        if should_unwrap_unspec_arg(name):
1773            line = var
1774        else:
1775            if isinstance(indexing, BlockPtrOptions):
1776                block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
1777                    name, var, indexing, other
1778                )
1779                line = f"tl.load({block_ptr}{other}{ep})"
1780                # add needed size=1 dimensions
1781                block_shape = [str(dim) for dim in indexing.block_shape]
1782                line = triton_reshape(line, block_shape, indexing.reshape_suffix)
1783            elif isinstance(original_index, sympy.Integer):
1784                line = f"tl.load({var} + ({original_index}))"
1785                append_broadcast = indexing.expand_str
1786            else:
1787                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"
1788
1789            dtype = V.graph.get_dtype(name)
1790            if (
1791                dtype in (torch.float16, torch.bfloat16)
1792                and config.triton.codegen_upcast_to_fp32
1793            ):
1794                line += ".to(tl.float32)"
1795            if dtype == torch.bool and torch.version.hip is None:
1796                # Workaround for https://github.com/openai/triton/issues/2151
1797                # tl.load returns int8 when loading from pointer to int1
1798                # NOTE: Currently causes hangs on bool UTs for ROCm
1799                line += ".to(tl.int1)"
1800
1801        load_buffer = self.get_load_buffer(indexing)
1802        result_var = self.cse.generate(load_buffer, line)
1803        assert isinstance(result_var, TritonCSEVariable)
1804        result_var.mask_vars = indexing.mask_vars  # type: ignore[assignment]
1805
1806        if append_broadcast:
1807            line = f"tl.broadcast_to({result_var}, {append_broadcast})"
1808            result_var = self.cse.generate(load_buffer, line)
1809
1810        if advance_block_ptr:
1811            load_buffer.writeline(advance_block_ptr)
1812
1813        if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex):
1814            self.outside_loop_vars.add(result_var)
1815
1816        return result_var
1817
1818    def store(
1819        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
1820    ) -> None:
1821        var = self.args.output(name)
1822        original_index = index
1823        indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None)
1824
1825        # Guard against write-after-read corruption in triton.
1826        # See https://github.com/openai/triton/issues/1615
1827        # This triton bug means that a load which is broadcasted over multiple
1828        # warps may see the result of a store that happens later in the triton
1829        # program. The workaround is to add a barrier before storing, which
1830        # enforces that all warps have already read the data.
1831        is_inplace = name in self.args.inplace_buffers
1832        is_broadcasted = self.is_broadcasted(original_index)
1833        if is_inplace and is_broadcasted:
1834            self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))
1835
1836        advance_block_ptr = None
1837        if isinstance(indexing, BlockPtrOptions):
1838            block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
1839                name, var, indexing
1840            )
1841            # block_ptr stores don't do implicit casting
1842            line = self.codegen_block_ptr_store_line(
1843                name, indexing, block_ptr, value, other
1844            )
1845        elif mode is None:
1846            line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
1847        elif mode == "atomic_add":
1848            line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')"
1849        else:
1850            raise NotImplementedError(f"store mode={mode}")
1851        self.stores.writeline(DeferredLine(name, line))
1852        if advance_block_ptr:
1853            self.stores.writeline(advance_block_ptr)
1854
1855        if not self.inside_reduction:
1856            self.outside_loop_vars.add(value)
1857
1858    def bucketize(
1859        self,
1860        values: CSEVariable,
1861        offsets_name: str,
1862        offsets_size: sympy.Expr,
1863        indexing_dtype: torch.dtype,
1864        right: bool,
1865    ) -> CSEVariable:
1866        """
1867        See [Note: Inductor bucketize op]
1868        """
1869
1870        # Triton performance for bucketize_binary_search is much better when the number
1871        # of threads equals the number of elements.
1872        # If we're trying to use a bucketize kernel, we should make sure that an
1873        # autotuning config with num_elements_per_warp=32 exists.
1874        self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32)
1875
1876        offsets_ptr = self.args.input(offsets_name)
1877        block_size = self.dense_size_str()
1878        offsets_size_str = self.index_to_str(offsets_size)
1879
1880        if indexing_dtype == torch.int32:
1881            triton_dtype = "tl.int32"
1882        elif indexing_dtype == torch.int64:
1883            triton_dtype = "tl.int64"
1884        else:
1885            raise NotImplementedError(
1886                "Bucketize only supports indexing with int32 and int64"
1887            )
1888
1889        result = self.cse.generate(
1890            self.compute,
1891            f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})",  # noqa: B950 line too long
1892        )
1893
1894        return result
1895
1896    def reduction_resize(self, value):
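        """
        Re-expand a reduced value to the kernel's dense rank by indexing the
        dropped (last) dimension with ``None``, e.g. ``value[:, None]`` for a
        2D kernel; 1D kernels just promote the value to a tensor.
        """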
1897        ndims = self.triton_tensor_ndim()
1898        if ndims == 1:
1899            return f"triton_helpers.promote_to_tensor({value})"
1900
1901        sizes = [":"] * ndims
1902        sizes[-1] = "None"
1903        return f"{value}[{', '.join(sizes)}]"
1904
1905    def reduction(
1906        self,
1907        dtype: torch.dtype,
1908        src_dtype: torch.dtype,
1909        reduction_type: ReductionType,
1910        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
1911    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
1912        assert self.inside_reduction
1913        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
1914        self.filter_masks(masks)
1915        masks = sorted(masks)
1916        if self._load_mask:
1917            masks.append(self._load_mask)
1918        reduction_range_prefix = self.range_trees[-1].prefix
1919
1920        # Say we have
1921        #     tmp0 = ops.constant(1, torch.int64)
1922        #     tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0)
1923        # tmp0 in the triton code is either a scalar or a single-element tensor,
1924        # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1.
1925        # To avoid this, we broadcast to the expected shape first.
1926        dense_size_str = self.dense_size_str()
1927        value = self._map_tuple_or_scalar(
1928            lambda v: self.cse.generate(
1929                self.compute, f"tl.broadcast_to({v}, {dense_size_str})"
1930            ),
1931            value,
1932        )
1933
1934        dim: int
1935        root_op: str
1936
1937        def final_reduction(value):
1938            use_helper = reduction_type in {"any", "max", "min", "prod"}
1939            module = "triton_helpers" if use_helper else "tl"
1940            if reduction_type in {"max", "min"}:
1941                return self.reduction_resize(
1942                    f"{module}.{reduction_type}2({value}, {dim})"
1943                )
1944            return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})")
1945
1946        def final_argreduce(buffer, result_var, value, index):
1947            buffer.splice(
1948                f"""\
1949                _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
1950                {result_var} = {self.reduction_resize(f'{result_var}_tmp')}
1951                """
1952            )
1953
1954        cache_key = (src_dtype, reduction_type, value)
1955        if cache_key in self.cse.reduction_cache:
1956            return self.cse.reduction_cache[cache_key]
1957
1958        dim = self.triton_tensor_ndim() - 1
1959        acc_type = triton_acc_type(src_dtype)
1960        result_var: Any = self.cse.newvar()
1961        result_var.mask_vars = OrderedSet(var for var in masks if var[0] != "r")
1962        cond = " & ".join(masks)
1963
1964        def where_cond(tval, fval):
1965            if not cond:
1966                return tval
1967            return TritonKernelOverrides.where(cond, tval, fval)
1968
1969        if self.persistent_reduction:
1970            default = ir.Reduction.default_value(reduction_type, src_dtype)
1971            default = self._map_tuple_or_scalar(constant_repr, default)
1972
1973            def _mask_value(value, default):
1974                return self.cse.generate(self.compute, where_cond(value, default))
1975
1976            if isinstance(value, tuple):
1977                masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
1978            else:
1979                masked_value = _mask_value(value, default)
1980
1981            if reduction_type in {"argmax", "argmin"}:
1982                accumulator_index = str(
1983                    self.cse.generate(
1984                        self.compute,
1985                        f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
1986                    )
1987                )
1988                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
1989                final_argreduce(
1990                    self.compute, result_var, masked_value, accumulator_index
1991                )
1992            elif reduction_type == "welford_reduce":
1993                # For persistent reductions, don't bother with
1994                # welford's algorithm since it uses more registers, and
1995                # taking two reductions doesn't increase memory usage.
1996                result_var = self.welford_reduce_fallback(dtype, value)
1997            elif reduction_type == "welford_combine":
1998                mean, m2, weight = masked_value
1999                welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})"
2000                mean, m2, weight = (self.cse.newvar() for _ in range(3))
2001                self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}")
2002
2003                result_var = tuple(
2004                    self.cse.generate(self.compute, self.reduction_resize(var_name))
2005                    for var_name in (mean, m2, weight)
2006                )
2007            else:
2008                result_var = self.cse.generate(
2009                    self.compute, final_reduction(masked_value)
2010                )
2011        else:
2012            accumulator = f"_{result_var}"
2013            default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
2014            default = self._map_tuple_or_scalar(constant_repr, default)
2015            if not isinstance(default, tuple):
2016                self.body.writeline(
2017                    f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
2018                )
2019
2020            if reduction_type in {"argmax", "argmin"}:
2021                accumulator_index = f"_{result_var}_index"
2022                long_max = torch.iinfo(torch.int64).max
2023                self.body.writeline(
2024                    f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
2025                )
2026                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
2027
2028                self.compute.splice(
2029                    f"""\
2030                {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
2031                    {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
2032                )
2033                {accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
2034                {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
2035                """
2036                )
2037                final_argreduce(self.suffix, result_var, accumulator, accumulator_index)
2038            elif is_welford_reduction(reduction_type):
2039                accumulator = f"{result_var}_mean"
2040                accumulator_m2 = f"{result_var}_m2"
2041                accumulator_weight = f"{result_var}_weight"
2042                self.body.writeline(
2043                    f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})"
2044                )
2045                self.body.writeline(
2046                    f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})"
2047                )
2048                self.body.writeline(
2049                    f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})"
2050                )
2051
2052                if reduction_type == "welford_combine":
2053                    mean, m2, weight = value
2054                    self.compute.splice(
2055                        f"""\
2056                    {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine(
2057                        {accumulator}, {accumulator_m2}, {accumulator_weight},
2058                        {mean}, {m2}, {weight}
2059                    )
2060                    """
2061                    )
2062                else:
2063                    assert reduction_type == "welford_reduce"
2064                    self.compute.splice(
2065                        f"""\
2066                    {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce(
2067                        {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0
2068                    )
2069                    """
2070                    )
2071
2072                self.compute.splice(
2073                    f"""\
2074                {accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
2075                {accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
2076                {accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
2077                """
2078                )
2079
2080                result_mean = result_var
2081                result_m2 = self.cse.newvar()
2082                result_weight = self.cse.newvar()
2083                self.suffix.splice(
2084                    f"""\
2085                {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = triton_helpers.welford(
2086                    {accumulator}, {accumulator_m2}, {accumulator_weight}, {dim}
2087                )
2088                {result_mean} = {self.reduction_resize(f'{result_mean}_tmp')}
2089                {result_m2} = {self.reduction_resize(f'{result_m2}_tmp')}
2090                {result_weight} = {self.reduction_resize(f'{result_weight}_tmp')}
2091                """
2092                )
2093                result_var = result_mean, result_m2, result_weight
2094            else:
2095                combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype)
2096                updated = combine_fn(accumulator, value)
2097                self.compute.writeline(
2098                    f"{accumulator} = {where_cond(updated, accumulator)}"
2099                )
2100
2101                if src_dtype == torch.bool:
2102                    # This is only really used for aten.any. It changes the
2103                    # final reduction of a non-persistent reduction from
2104                    #     tmp5 = triton_helpers.max(_tmp5, 1)[:, None]
2105                    # to
2106                    #     tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1)
2107                    # which is needed because tl.reduce doesn't support tl.int1
2108                    accumulator = f"{accumulator}.to(tl.int8)"
2109                    result_type = triton_compute_type(dtype)
2110                    self.suffix.writeline(
2111                        f"{result_var} = {final_reduction(accumulator)}.to({result_type})"
2112                    )
2113                else:
2114                    self.suffix.writeline(
2115                        f"{result_var} = {final_reduction(accumulator)}"
2116                    )
2117
2118        self.cse.reduction_cache[cache_key] = result_var
2119
2120        if isinstance(result_var, tuple):
2121            assert all(isinstance(x, TritonCSEVariable) for x in result_var)
2122            self.outside_loop_vars |= OrderedSet(result_var)
2123        else:
2124            assert isinstance(result_var, TritonCSEVariable)
2125            self.outside_loop_vars.add(result_var)
2126
2127        return result_var
2128
2129    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
2130        assert self.inside_reduction
2131        self.inside_reduction = False
2132        indexing = self.indexing(index, block_ptr=True)
2133        self.inside_reduction = True
2134        var = self.args.output(name)
2135
2136        if isinstance(indexing, BlockPtrOptions):
2137            self.suffix.writeline(
2138                DeferredLine(
2139                    name,
2140                    self.codegen_block_ptr_store_line(
2141                        name,
2142                        indexing,
2143                        indexing.format(var),
2144                        value,
2145                        f", boundary_check={indexing.boundary_check()!r}",
2146                    ),
2147                )
2148            )
2149        else:
2150            assert isinstance(indexing, IndexingOptions)
2151            self.suffix.writeline(
2152                DeferredLine(
2153                    name,
2154                    f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
2155                )
2156            )
2157
2158    def _lift_helper(self, fn, num_args) -> str:
2159        # Lift IR function for scan operations into a triton function
2160        # in the global namespace
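        # For a cumulative sum, for instance, the lifted helper renders roughly
        # as (sketch only; the name suffix and body depend on the ops fn emits):
        #
        #     @triton.jit
        #     def _triton_helper_fn_add0(arg0_0, arg1_0):
        #         tmp0 = arg0_0 + arg1_0
        #         return tmp0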
2161        helper = IndentedBuffer()
2162        helper.writeline("@triton.jit")
2163        args = [tuple(f"arg{i}_{n}" for n in range(num_args)) for i in range(2)]
2164        signature = ", ".join(itertools.chain.from_iterable(args))
2165        helper.writeline(f"def {{name}}({signature}):")
2166
2167        cse = CSE(prefix="", suffix="")
2168        overrides = TritonOverrides(V.MockHandler())
2169
2170        # Build a name that changes depending on fn to work around a triton bug
2171        # where the combine_fn to reduce and scan is not hashed, and so different
2172        # scan ops may collide in the triton cache.
2173        # This is fixed with the latest triton pin, but not the triton-rocm pin.
2174        helper_name = "_triton_helper_fn"
2175
2176        class CSEProxy:
2177            def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
2178                def inner(*args, **kwargs):
2179                    nonlocal helper_name
2180                    helper_name += f"_{name}"
2181                    return cse.generate(
2182                        helper,
2183                        getattr(overrides, name)(*args, **kwargs),
2184                    )
2185
2186                return inner
2187
2188        with helper.indent(), V.set_ops_handler(CSEProxy()):
2189            outputs = fn(*args)
2190            outputs = ", ".join(str(output) for output in outputs)
2191            helper.writeline(f"return {outputs}")
2192
2193        return self.helper_functions.add(helper.getvalue(), base_name=helper_name)
2194
2195    def scan(
2196        self,
2197        dtypes: Tuple[torch.dtype, ...],
2198        combine_fn: Callable[
2199            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
2200        ],
2201        values: Tuple[CSEVariable, ...],
2202    ) -> Tuple[CSEVariable, ...]:
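        """
        Codegen an associative scan (e.g. cumsum) over the reduction dimension
        using ``tl.associative_scan`` with a lifted ``combine_fn`` helper.  For
        non-persistent reductions, per-iteration partial scans are combined
        with accumulators carried across ``roffset`` iterations.
        """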
2203        assert self.inside_reduction
2204        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
2205        self.filter_masks(masks)
2206        masks = sorted(masks)
2207        assert not self._load_mask, "ops.scan not supported inside ops.masked"
2208        reduction_range_prefix = self.range_trees[-1].prefix
2209
2210        broadcasted_values = []
2211        accumulators = []
2212
2213        cse_compute = functools.partial(self.cse.generate, self.compute)
2214        combine_helper_fn = self._lift_helper(combine_fn, len(values))
2215        dim = self.triton_tensor_ndim() - 1
2216
2217        for value, dtype in zip(values, dtypes):
2218            acc_type = triton_acc_type(dtype)
2219            cond = " & ".join(masks)
2220
2221            value_dtype = self.cse.generate(
2222                self.compute,
2223                f"{value}.to({triton_compute_type(dtype)})",
2224            )
2225            value = self.cse.generate(
2226                self.compute,
2227                f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})",
2228            )
2229            broadcasted_values.append(value)
2230
2231            acc_type = triton_acc_type(dtype)
2232            cond = " & ".join(masks)
2233
2234            if not self.persistent_reduction:
2235                accumulator = self.cse.newvar()
2236                reduced_size = self.dense_size_list()
2237                reduced_size[-1] = "1"
2238                reduced_size = f"[{', '.join(reduced_size)}]"
2239
2240                default = "float('nan')" if dtype.is_floating_point else "-1"
2241                self.body.writeline(
2242                    f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})"
2243                )
2244
2245                accumulators.append(accumulator)
2246
2247        def csv(values):
2248            return " ".join(f"{value}," for value in values)
2249
2250        def cse_multiple(line, n, masks):
2251            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
2252            if all(cache_key in self.cse.cache for cache_key in cache_keys):
2253                return [self.cse.cache[cache_key] for cache_key in cache_keys]
2254            result_vars = [self.cse.newvar() for _ in range(n)]
2255            self.compute.writeline(
2256                f"{csv(result_vars)} = {line}",
2257            )
2258            for result_var, cache_key in zip(result_vars, cache_keys):
2259                if masks:
2260                    result_var.mask_vars = masks  # type: ignore[attr-defined]
2261                self.cse.cache[cache_key] = result_var
2262            return tuple(result_vars)
2263
2264        partial_scan_vars = cse_multiple(
2265            f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})",
2266            len(values),
2267            masks,
2268        )
2269
2270        if not self.persistent_reduction:
2271            # tl.reduce doesn't work for non-commutative operators, so instead
2272            # of repeating the scan op as a reduction, we use sum to select the
2273            # last scan value
2274            partial_reduce_vars = [
2275                cse_compute(
2276                    f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)"
2277                )
2278                for partial_scan_var in partial_scan_vars
2279            ]
2280            accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars))
2281            full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars)
2282            result_vars = [
2283                cse_compute(f"tl.where(roffset > 0, {full_scan}, {partial_scan})")
2284                for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars)
2285            ]
2286            for acc_next, accumulator, partial_reduce in zip(
2287                accs_next, accumulators, partial_reduce_vars
2288            ):
2289                self.compute.writeline(
2290                    f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})"
2291                )
2292        else:
2293            result_vars = partial_scan_vars
2294
2295        for result_var in result_vars:
2296            result_var.mask_vars = masks  # type: ignore[attr-defined]
2297
2298        return tuple(result_vars)
2299
2300    def sort(
2301        self,
2302        dtypes: Tuple[torch.dtype, ...],
2303        values: Tuple[CSEVariable, ...],
2304        stable: bool,
2305        descending: bool,
2306    ) -> Tuple[CSEVariable, ...]:
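        """
        Codegen a sort along the reduction dimension via
        ``triton_helpers.sort_with_index``.  Only persistent reductions are
        supported, and currently only the two-input (values, indices) form.
        """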
2307        assert self.inside_reduction
2308        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
2309        self.filter_masks(masks)
2310        masks = sorted(masks)
2311        assert not self._load_mask, "ops.sort not supported inside ops.masked"
2312        assert (
2313            self.persistent_reduction
2314        ), "ops.sort is only supported in persistent reductions"
2315        reduction_range_prefix = self.range_trees[-1].prefix
2316
2317        cse_compute = functools.partial(self.cse.generate, self.compute)
2318        dim = self.triton_tensor_ndim() - 1
2319
2320        broadcasted_values = [
2321            cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})")
2322            for value in values
2323        ]
2324
2325        def csv(values):
2326            return " ".join(f"{value}," for value in values)
2327
2328        def cse_multiple(line, n, masks):
2329            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
2330            if all(cache_key in self.cse.cache for cache_key in cache_keys):
2331                return [self.cse.cache[cache_key] for cache_key in cache_keys]
2332            result_vars = [self.cse.newvar() for _ in range(n)]
2333            self.compute.writeline(
2334                f"{csv(result_vars)} = {line}",
2335            )
2336            for result_var, cache_key in zip(result_vars, cache_keys):
2337                if masks:
2338                    result_var.mask_vars = masks  # type: ignore[attr-defined]
2339                self.cse.cache[cache_key] = result_var
2340            return tuple(result_vars)
2341
2342        assert self.range_trees[-1].prefix == "r"
2343        rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel"
2344
2345        if len(values) == 2:
2346            line = (
2347                f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]},"
2348                f" {rnumel}, {dim}, stable={stable}, descending={descending})"
2349            )
2350            result_vars = cse_multiple(line, len(values), masks)
2351        else:
2352            raise AssertionError("Unhandled sort")
2353
2354        for result_var, input_var in zip(result_vars, values):
2355            result_var.mask_vars = masks  # type: ignore[attr-defined]
2356            result_var.bounds = input_var.bounds
2357
2358        return tuple(result_vars)
2359
2360    def codegen_body(self):
2361        """
2362        Concat output code from index_code, loads, compute, stores,
2363        suffix into self.body.
2364
2365        For pointwise kernels, this is called just once at the end.
2366
2367        For reduction kernels, this generates a loop over the reduction
2368        axis.
2369        """
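        # Sketch of the structure emitted for a looped (non-persistent)
        # reduction; the exact r-index header comes from
        # iteration_ranges_codegen_header:
        #
        #     for roffset in range(0, rnumel, RBLOCK):
        #         rindex = roffset + rbase
        #         rmask = rindex < rnumel
        #         <loads> <compute> <stores>
        #     <suffix>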
2370        if not (
2371            self.indexing_code
2372            or self.loads
2373            or self.stores
2374            or self.compute
2375            or self.suffix
2376        ):
2377            return
2378
2379        if self.inside_reduction and self.range_trees[-1].is_loop:
2380            self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
2381            with self.body.indent():
2382                # last range tree is always reduction
2383                self.iteration_ranges_codegen_header(self.range_trees[-1], self.body)
2384                self.body.splice(self.indexing_code)
2385                self.body.splice(self.loads)
2386                self.body.splice(self.compute)
2387                self.body.splice(self.stores)
2388
2389            # invalidate any caches that came from inside the reduction loop
2390            self.cse.invalidate(self.outside_loop_vars)
2391            self.range_trees[-1].cache_clear()
2392        else:
2393            self.body.splice(self.indexing_code)
2394            self.body.splice(self.loads)
2395            self.body.splice(self.compute)
2396            self.body.splice(self.stores)
2397        self.body.splice(self.suffix)
2398        self.indexing_code.clear()
2399        self.loads.clear()
2400        self.compute.clear()
2401        self.stores.clear()
2402        self.suffix.clear()
2403
2404    def codegen_kernel_benchmark(self, num_gb, grid=None):
2405        result = IndentedBuffer()
2406        argdefs, call_args, signature, _ = self.args.python_argdefs()
2407
2408        result.writelines(["", "", "def get_args():"])
2409        with result.indent():
2410            name_cnt = itertools.count()
2411            var_names = []
2412            for arg_name, arg_sig in zip(call_args, signature):
2413                var_name = f"arg_{next(name_cnt)}"
2414                buf = V.graph.try_get_buffer(arg_name)
2415                if buf:
2416                    result.writeline(
2417                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
2418                    )
2419                elif arg_name in V.graph.constants:
2420                    # note that random seed is put in V.graph.constants
2421                    const_tensor = V.graph.constants[arg_name]
2422                    result.writeline(
2423                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
2424                    )
2425                elif isinstance(arg_sig, SizeArg):
2426                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
2427
2428                    # Force the seed_offset to be 0 so calls to the same kernel
2429                    # using different seed offsets will have the same benchmark harness.
2430                    # We can dedup kernel definitions in this case.
2431                    if "seed_offset" in arg_sig.name:
2432                        symval_hint = 0
2433                    result.writeline(f"{var_name} = {symval_hint}")
2434                elif isinstance(arg_sig, WorkspaceArg):
2435                    device = V.graph.scheduler.get_current_device_or_throw()
2436                    nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes)
2437                    result.writeline(
2438                        f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)"
2439                    )
2440                else:
2441                    raise KeyError(
2442                        f"Could not find the buffer or const tensor for {arg_name}"
2443                    )
2444                var_names.append(var_name)
2445            result.writeline(f"return {', '.join(var_names)},")
2446
2447        result.writelines(["\n", "\n", "def call(args):"])
2448        if grid is None:
2449            grid = []
2450            extra_args = []
2451            extra_args_str = None
2452            for tree in self.active_range_trees():
2453                expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
2454                extra_args.append(expr)
2455                if tree.prefix != "r":
2456                    grid.append(expr)
2457            if self.need_numel_args():
2458                extra_args_str = ", ".join(map(str, extra_args)) + ", "
2459            else:
2460                extra_args_str = ""
2461            grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})"
2462        else:
2463            grid_arg = f"grid={grid}"
2464        current_device = V.graph.scheduler.get_current_device_or_throw()
2465        index = current_device.index
2466        with result.indent():
2467            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
2468            with result.indent():
2469                result.writeline(
2470                    V.graph.device_ops.set_device(index)
2471                )  # no-op to ensure context
2472                stream_name = f"stream{index}"
2473                result.writeline(f"{stream_name} = get_raw_stream({index})")
2474                result.writeline(
2475                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
2476                )
2477
2478        # benchmark all configs
2479        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
2480        with result.indent():
2481            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
2482            with result.indent():
2483                result.writeline(
2484                    V.graph.device_ops.set_device(index)
2485                )  # no-op to ensure context
2486                result.writeline(
2487                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
2488                )
2489
2490        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
2491        with result.indent():
2492            result.writeline(
2493                "from torch._inductor.runtime.benchmarking import benchmarker"
2494            )
2495            result.writeline("")
2496
2497            result.writeline("args = get_args()")
2498            result.writeline(
2499                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)"
2500            )
2501            result.writeline(f"num_gb = {num_gb}")
2502            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
2503            result.writeline(
2504                'print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")'
2505            )
2506
2507        return result
2508
2509    def imports_for_benchmark_kernel(self):
2510        return textwrap.dedent(
2511            """
2512            from torch._dynamo.testing import rand_strided
2513            {}
2514            import torch
2515            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
2516        """.format(
2517                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
2518            )
2519        )
2520
2521    def _get_heuristic(self):
2522        if self.persistent_reduction:
2523            assert self.inside_reduction
2524            return "persistent_reduction"
2525        elif self.inside_reduction:
2526            return "reduction"
2527        return "pointwise"
2528
2529    @staticmethod
2530    def inductor_meta_common():
2531        inductor_meta = {
2532            "backend_hash": torch.utils._triton.triton_hash_with_backend(),
2533            "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
2534            "assert_indirect_indexing": config.assert_indirect_indexing,
2535            "autotune_local_cache": config.autotune_local_cache,
2536            "autotune_pointwise": config.triton.autotune_pointwise,
2537            "autotune_remote_cache": config.autotune_remote_cache,
2538            "force_disable_caches": config.force_disable_caches,
2539            "dynamic_scale_rblock": config.dynamic_scale_rblock,
2540            "max_autotune": config.max_autotune,
2541            "max_autotune_pointwise": config.max_autotune_pointwise,
2542            "min_split_scan_rblock": config.triton.min_split_scan_rblock,
2543            "spill_threshold": config.triton.spill_threshold,
2544            "store_cubin": config.triton.store_cubin,
2545        }
2546        if torch.version.hip is not None:
2547            inductor_meta["is_hip"] = True
2548        if config.is_fbcode():
2549            inductor_meta["is_fbcode"] = True
2550        if config.profile_bandwidth:
2551            inductor_meta["profile_bandwidth"] = config.profile_bandwidth
2552            inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
2553            inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
2554            inductor_meta[
2555                "profile_bandwidth_with_do_bench_using_profiling"
2556            ] = config.profile_bandwidth_with_do_bench_using_profiling
2557        if config.coordinate_descent_tuning:
2558            inductor_meta[
2559                "coordinate_descent_tuning"
2560            ] = config.coordinate_descent_tuning
2561            inductor_meta[
2562                "coordinate_descent_search_radius"
2563            ] = config.coordinate_descent_search_radius
2564            inductor_meta[
2565                "coordinate_descent_check_all_directions"
2566            ] = config.coordinate_descent_check_all_directions
2567        return inductor_meta
2568
2569    def codegen_kernel(self, name=None):
2570        code = IndentedBuffer()
2571
2572        size_hints = []
2573        for numel in self.numels:
2574            numel_hint = V.graph.sizevars.symbolic_hint(numel)
2575            if not isinstance(numel_hint, (int, sympy.Integer)):
2576                # This default heuristic hint was picked carefully: it is
2577                # large, to ensure that we don't shrink the block size (since
2578                # if you don't have many elements, it'd be wasteful to pick a
2579                # large block size).  Since we don't know how many elements we
2580                # might have, we should be OK with some inefficiency to make
2581                # sure we handle the large case well.  8192 is the largest
2582                # block size we support, so we pick that.
2583                #
2584                # If we have a better hint for unbacked SymInts (e.g., because
2585                # a user told us, or we are tracking upper bounds) we could
2586                # use that here.
2587                size_hint = 8192
2588            else:
2589                size_hint = next_power_of_2(int(numel_hint))
2590            size_hints.append(size_hint)
2591
2592        if not self.inside_reduction:
2593            size_hints.pop()
2594
2595        heuristics = self._get_heuristic()
2596
2597        if name is None:
2598            code.splice(gen_common_triton_imports())
2599
2600            if config.benchmark_kernel:
2601                code.splice(self.imports_for_benchmark_kernel())
2602
2603        argdefs, _, signature, _ = self.args.python_argdefs()
2604        # maps actual expression to SizeArg if it is in sizevars replacements
2605        for i, arg in enumerate(signature):
2606            if isinstance(arg, SizeArg):
2607                # mypy is unhappy about the sympy.Expr
2608                # type for the key of the dict below
2609                symbol = cast(sympy.Symbol, arg.expr)
2610                if symbol in V.graph.sizevars.inv_precomputed_replacements:
2611                    signature[i] = SizeArg(
2612                        arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
2613                    )
2614
2615        mutated_args: OrderedSet[str] = OrderedSet()
2616        for mutation in self.mutations:
2617            if mutation in self.args.input_buffers:
2618                mutated_args.add(self.args.input_buffers[mutation])
2619            if (
2620                mutation in self.args.inplace_buffers
2621                and mutation not in V.graph.removed_buffers
2622                and mutation not in self.removed_buffers
2623            ):
2624                mutated_args.add(self.args.inplace_buffers[mutation].inner_name)
2625            if mutation in self.args.output_buffers:
2626                mutated_args.add(self.args.output_buffers[mutation])
2627
2628        # workspace arguments are mutated, but are not marked as mutations in self.mutations
2629        # because their buffers are added during codegen, and aren't tracked during
2630        # lowering/scheduling. So we add them as mutated_args explicitly below.
2631        #
2632        # In the logic below, we only mark the workspaces as mutated if they are marked with
2633        # zero_fill: that's because, if we don't expect the buffer to be pre-filled with
2634        # zeros, then, although we still mutate the data, we don't care about those
2635        # mutations because we don't make any assumptions about the contents of the
2636        # workspace buffer.
2637        for argname, arg in zip(argdefs, signature):
2638            if isinstance(arg, WorkspaceArg) and arg.zero_fill:
2639                mutated_args.add(argname)
2640
2641        mutated_args = sorted(mutated_args)
2642
2643        triton_meta_signature = signature_to_meta(
2644            signature, size_dtype=self.index_dtype
2645        )
2646        triton_meta = {
2647            "signature": triton_meta_signature,
2648            "device": DeviceProperties.create(
2649                V.graph.scheduler.get_current_device_or_throw()
2650            ),
2651            "constants": {},
2652        }
2653
2654        inductor_meta = {
2655            "autotune_hints": set(self.autotune_hints),
2656            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
2657            "mutated_arg_names": mutated_args,
2658            "no_x_dim": self.no_x_dim,
2659            "num_load": self.num_load,
2660            "num_reduction": self.num_reduction,
2661            **self.inductor_meta_common(),
2662        }
2663
2664        num_gb = None
2665        if config.benchmark_kernel or config.profile_bandwidth:
2666            num_gb = self.estimate_kernel_num_bytes() / 1e9
2667            inductor_meta["kernel_num_gb"] = num_gb
2668
2669        for tree in self.active_range_trees():
2670            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
2671            signature.append(sizearg)
2672            triton_meta_signature[len(argdefs)] = signature_of(
2673                sizearg, size_dtype=self.index_dtype
2674            )
2675            argdefs.append(f"{tree.prefix}numel")
2676            # constexpr version causes issues, see
2677            # https://github.com/pytorch/torchdynamo/pull/1362
2678            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
2679            #     tree.numel
2680            # )
2681            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
2682        triton_meta["configs"] = [config_of(signature)]
2683
2684        # Triton compiler includes equal_to_1 args into constants even
2685        # when they are not constexpr. otherwise there may be a segfault
2686        # during launching the Inductor-compiled Triton kernel.
2687        # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
2688        # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
2689        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
2690            triton_meta["constants"][arg_num] = 1  # type: ignore[index]
2691
2692        self.triton_meta = triton_meta
2693
2694        for tree in self.range_trees:
2695            if tree.prefix == "r" and self.persistent_reduction:
2696                # RBLOCK for persistent_reduction is defined in codegen_static_numels
2697                continue
2698            if tree.tensor_dim is None:
2699                continue
2700            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")
2701
2702        self.codegen_body()
2703
2704        for helper in self.helper_functions:
2705            code.writeline("")
2706            code.splice(helper)
2707
2708        if self.inside_reduction:
2709            reduction_hint = self.reduction_hint
2710            heuristics_line = f"""
2711                @triton_heuristics.{heuristics}(
2712                    size_hints={size_hints!r},
2713                    reduction_hint={reduction_hint},
2714                    filename=__file__,
2715                    triton_meta={triton_meta!r},
2716                    inductor_meta={inductor_meta!r}
2717                )
2718                @triton.jit
2719            """
2720        else:
2721            tile_hint = ""
2722            if len(size_hints) == 2:
2723                if len(signature) == 4:  # input, output and 2 args
2724                    tile_hint = "tile_hint=TileHint.SQUARE,"
2725                else:
2726                    tile_hint = "tile_hint=TileHint.DEFAULT,"
2727            heuristics_line = f"""
2728                @triton_heuristics.{heuristics}(
2729                    size_hints={size_hints!r}, {tile_hint}
2730                    filename=__file__,
2731                    triton_meta={triton_meta!r},
2732                    inductor_meta={inductor_meta!r},
2733                    min_elem_per_thread={self.min_elem_per_thread}
2734                )
2735                @triton.jit
2736            """
2737        code.splice(heuristics_line)
2738        code.writeline(
2739            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
2740        )
2741        with code.indent():
2742            self.codegen_static_numels(code)
2743            for old, new in self.args.aliases():
2744                code.writeline(f"{old} = {new}")
2745            code.splice(self.body)
2746
2747        if config.benchmark_kernel:
2748            code.splice(self.codegen_kernel_benchmark(num_gb))
2749
2750        return code.getvalue()
2751
2752    def _get_persistent_RBLOCK(self, rnumel):
2753        rnumel = V.graph.sizevars.simplify(rnumel)
2754        if isinstance(rnumel, (sympy.Integer, int)):
2755            val = int(rnumel)
2756            val = next_power_of_2(val)
2757        else:
2758            val = 128
2759            while not V.graph.sizevars.statically_known_leq(rnumel, val):
2760                assert val <= 16 * 1024, f"Failed to find static RBLOCK for {rnumel}"
2761                val *= 2
2762        return val
2763
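    # A hedged, illustrative sketch for _get_persistent_RBLOCK above (added
    # commentary, not part of the original source): a static rnumel of 768 yields
    # next_power_of_2(768) == 1024, so the persistent-reduction kernel gets
    #   RBLOCK: tl.constexpr = 1024
    # For a symbolic rnumel, the loop starts at 128 and keeps doubling until the
    # value is statically known to be an upper bound, with an assert guarding
    # against an unboundedly large RBLOCK.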
2764    def codegen_static_numels(self, code):
2765        """
2766        We get a small speedup from hard coding numels if they are static.
2767
2768        This code stomps on the passed-in values by writing a constant to the top of the kernel.
2769
2770        In a kernel like:
2771        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
2772
2773        We would add
2774        xnumel = 4096
2775        rnumel = 768
2776
2777        after the signature and before the kernel code, if we decided to make these static. Since the value is
2778        hardcoded, it becomes a better signal to triton on how to unroll and do some static indexing. So it's not
2779        so much that downstream knows that it's a static numel, as that you just plop a constant into the kernel.
2780        """
2781        for tree in self.range_trees:
2782            if tree.prefix != "r" or self.inside_reduction:
2783                simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
2784                if isinstance(simplified_tree_numel, (sympy.Integer, int)):
2785                    code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
2786
2787            if tree.prefix == "r" and self.persistent_reduction:
2788                val = self._get_persistent_RBLOCK(tree.numel)
2789                code.writeline(f"RBLOCK: tl.constexpr = {val}")
2790
2791            if tree.prefix == "x" and self.no_x_dim:
2792                code.writeline("XBLOCK: tl.constexpr = 1")
2793
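    # Illustrative example for codegen_static_numels above (added commentary, not
    # part of the original source): for a persistent reduction with static sizes
    # xnumel=4096 and rnumel=768, the generated kernel would start with roughly
    #   xnumel = 4096
    #   rnumel = 768
    #   RBLOCK: tl.constexpr = 1024
    # stomping on the xnumel/rnumel arguments passed in at launch time.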
2794    def _get_grid_fn(self):
2795        return "grid"
2796
2797    def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
2798        # TODO(jansel): if there are constants, we shouldn't bother passing them as args
2799        for tree in self.range_trees:
2800            if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
2801                expr = tree.numel
2802            else:
2803                expr = V.graph.wrapper_code.generate_numel_expr(name, tree)
2804
2805            if tree.prefix != "r" or self.inside_reduction:
2806                call_args.append(expr)
2807                arg_types.append(type(expr))
2808            if tree.grid_dim is not None:
2809                grid.append(expr)
2810
2811    def call_kernel(self, name: str, node: Optional[IRNode] = None):
2812        wrapper = V.graph.wrapper_code
2813        wrapper.write_triton_header_once()
2814        _, call_args, _, arg_types = self.args.python_argdefs()
2815        grid: List[Any] = []
2816        self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
2817        current_device = V.graph.scheduler.get_current_device_or_throw()
2818
2819        if self.args.workspace_arg is not None:
2820            ws = self.args.workspace_arg
2821            wrapper.generate_workspace_allocation(
2822                ws.nbytes, current_device, ws.zero_fill
2823            )
2824
2825        grid = wrapper.generate_default_grid(name, grid)
2826        wrapper.generate_kernel_call(
2827            name,
2828            call_args,
2829            grid,
2830            current_device.index,
2831            cuda=True,
2832            triton=True,
2833            arg_types=arg_types,
2834            grid_fn=self._get_grid_fn(),
2835            triton_meta=self.triton_meta,
2836        )
2837
2838        if self.args.workspace_arg is not None:
2839            wrapper.writeline(wrapper.make_free_by_names(["workspace"]))
2840
2841    def codegen_nan_check(self):
2842        wrapper = V.graph.wrapper_code
2843        _, call_args, arg_signatures, _ = self.args.python_argdefs()
2844        for arg, arg_signature in zip(call_args, arg_signatures):
2845            if isinstance(arg_signature, TensorArg):
2846                if V.graph.cpp_wrapper:
2847                    if config.abi_compatible:
2848                        wrapper.writeline(
2849                            f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));'
2850                        )
2851                    else:
2852                        wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});')
2853                else:
2854                    line = f"assert not {arg}.isnan().any().item()"
2855                    wrapper.writeline(line)
2856                    line = f"assert not {arg}.isinf().any().item()"
2857                    wrapper.writeline(line)
2858
2859    def create_cse_var(self, *args, **kwargs):
2860        return TritonCSEVariable(*args, **kwargs)
2861
2862    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
2863        line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
2864        if entry.root.is_loop:
2865            self.indexing_code.writeline(line)
2866        else:
2867            # lift non-reduction stores outside loop
2868            self.body.writeline(line)
2869
2870    def iteration_ranges_ranges_code(self, entry):
2871        assert entry.tensor_dim is not None
2872        size = self.indexing_size_str(entry.tensor_dim)
2873        index_dtype = self.index_dtype
2874        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
2875        return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{convert}"
2876
2877    def iteration_ranges_scalar_code(self, entry, value):
2878        index_dtype = self.index_dtype
2879        ndim = self.triton_tensor_ndim()
2880        size = [1] * ndim
2881        return f"tl.full({size}, {value}, {index_dtype})"
2882
2883    def iteration_ranges_get_pid(self, entry):
2884        assert entry.grid_dim is not None
2885        key = f"tl.program_id({entry.grid_dim})"
2886        # y_grid has a limit, so express it in terms of y and z in case of overflow.
2887        # z grid is only exercised when max_tiles == 3 (off by default).
2888        if (
2889            entry.grid_dim == 1
2890            and not entry.has_zdim
2891            and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
2892        ):
2893            # For ynumel larger than max_ygrid, we need to use zdim.
2894            # For each z dimension, there are tl.num_programs(1) y-blocks; that count is passed in via grid(x, y, z).
2895            # So we need to add tl.program_id(z) * tl.num_programs(y) * YBLOCK to get the correct yoffset.
2896            key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))"
2897        pid = entry.pid_cache.get(key, key)
2898        if self.index_dtype != "tl.int32":
2899            return f"{pid}.to({self.index_dtype})"
2900        return pid
2901
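    # Illustrative sketch for iteration_ranges_get_pid above (added commentary, not
    # part of the original source): for a y range (grid_dim == 1) whose numel may
    # exceed the maximum y grid, the program id is expressed as roughly
    #   (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1))
    # so the z dimension of the launch grid absorbs the overflow; the caller then
    # multiplies this by YBLOCK to form yoffset.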
2902    def _has_constant_mask(self, tree: IterationRangesRoot):
2903        if not self.optimize_mask:
2904            return False
2905        if V.graph.sizevars.statically_known_equals(tree.numel, 1):  # type: ignore[arg-type]
2906            return True
2907        # Masks are superfluous if numel is a multiple of BLOCK
2908        # (We use the fact that BLOCK is required by triton to be a power of 2)
2909        if tree.prefix == "r" and self.persistent_reduction:
2910            max_block = self._get_persistent_RBLOCK(tree.numel)
2911        elif tree.prefix == "x" and self.no_x_dim:
2912            max_block = 1
2913        else:
2914            if tree.prefix.upper() not in TRITON_MAX_BLOCK:
2915                return False
2916            max_block = TRITON_MAX_BLOCK[tree.prefix.upper()]
2917
2918        # Optional optimization: if block divides numel exactly, we will
2919        # never need to do a masked load to handle stragglers at the end.
2920        # It's faster to avoid masking at all.  But it is sound to always
2921        # mask.
2922        return V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block)
2923
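    # Illustrative example for _has_constant_mask above (added commentary, not part
    # of the original source; the concrete TRITON_MAX_BLOCK value is an assumption):
    # if xnumel is statically 4096 and the maximum XBLOCK is 2048, then xnumel is a
    # multiple of every power-of-two XBLOCK the autotuner can pick, so the usual
    #   xmask = xindex < xnumel
    # can be replaced by a constant all-True mask.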
2924    def filter_masks(self, mask_vars):
2925        for tree in self.range_trees:
2926            if self._has_constant_mask(tree):
2927                mask_vars.discard(f"{tree.prefix}mask")
2928
2929    def iteration_ranges_codegen_header(self, entry, code):
2930        x = entry.prefix
2931        if entry.is_loop:
2932            code.writeline(f"{entry.name} = {x}offset + {x}base")
2933        elif entry.grid_dim is None:
2934            # no need for "{x}offset = " here
2935            code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}")
2936            code.writeline(f"{x}offset = 0")
2937        else:
2938            if entry.tensor_dim is not None:
2939                line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}"
2940            else:
2941                line = self.iteration_ranges_scalar_code(entry, f"{x}offset")
2942            code.writelines(
2943                [
2944                    f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {x.upper()}BLOCK",
2945                    f"{entry.name} = {line}",
2946                ]
2947            )
2948
2949        if self._has_constant_mask(entry):
2950            sizes = self.dense_size_str()
2951            code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)")
2952        else:
2953            code.writeline(f"{x}mask = {entry.name} < {x}numel")
2954
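    # A hedged, illustrative sketch of the header emitted above (added commentary,
    # not part of the original source): for a typical 1D pointwise kernel this is
    #   xoffset = tl.program_id(0) * XBLOCK
    #   xindex = xoffset + tl.arange(0, XBLOCK)[:]
    #   xmask = xindex < xnumel
    # while a reduction ("r") range that is a loop instead emits
    #   rindex = roffset + rbase
    # inside the reduction loop.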
2955
2956class TritonScheduling(SIMDScheduling):
2957    int32_type = "tl.int32"
2958    int64_type = "tl.int64"
2959    kernel_type = TritonKernel
2960    backend_features = dict.fromkeys(  # dict for deterministic order
2961        [
2962            BackendFeature.FOREACH,
2963            BackendFeature.BUCKETIZE,
2964            BackendFeature.INPLACE_BUFFERS,
2965            BackendFeature.MASKED_SCATTER_WITH_INDEX,
2966            BackendFeature.SCAN,
2967            BackendFeature.TRITON_TEMPLATES,
2968        ]
2969    )
2970    if torch.version.hip is None:
2971        backend_features.update(
2972            dict.fromkeys(
2973                [
2974                    # TODO: Move this above when ROCm triton adds support for multiple inputs
2975                    BackendFeature.TUPLE_REDUCTION,
2976                    BackendFeature.SORT,
2977                ]
2978            )
2979        )
2980
2981    @classmethod
2982    def get_backend_features(cls, device: torch.device):
2983        return cls.backend_features
2984
2985    def codegen_comment(self, node_schedule):
2986        wrapper = V.graph.wrapper_code
2987        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
2988        if origins:
2989            wrapper.writeline(origins)
2990
2991        if config.debug_fusion:
2992            from torch._inductor.scheduler import (
2993                BaseSchedulerNode,
2994                ForeachKernelSchedulerNode,
2995            )
2996
2997            if not any(
2998                isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule
2999            ):
3000                # We should probably look at which nodes are inside a foreach
3001                # schedule node
3002                node_names = [
3003                    n.get_name()
3004                    for n in node_schedule
3005                    if isinstance(n, BaseSchedulerNode)
3006                ]
3007                wrapper.writeline(
3008                    f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
3009                )
3010
3011    def define_kernel(self, src_code, node_schedule, kernel):
3012        wrapper = V.graph.wrapper_code
3013        if src_code in wrapper.src_to_kernel:
3014            kernel_name = wrapper.src_to_kernel[src_code]
3015        else:
3016            fused_name = (
3017                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
3018                if config.triton.descriptive_names
3019                else ""
3020            )
3021            kernel_category = get_kernel_category_by_source_code(src_code)[:3]
3022            kernel_name = "_".join(
3023                ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
3024            )
3025            # use the original src_code as the key
3026            wrapper.src_to_kernel[src_code] = kernel_name
3027            subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"
3028
3029            # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name
3030            # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set
3031            # to "triton_" to maximize caching opportunities (when unique_kernel_names = False).
3032            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
3033            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name)
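            # Illustrative example (added commentary, not part of the original
            # source): a pointwise kernel might be named "triton_poi_fused_add_0";
            # that full name survives in DESCRIPTIVE_NAME positions for profiling,
            # while the kernel's `def` line becomes the generic "def triton_(...)"
            # when unique_kernel_names is off, to improve cache reuse.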
3034
3035            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
3036            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
3037            src_code = src_code.replace("#pragma CMT", "#")
3038
3039            basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")
3040
3041            compile_wrapper = IndentedBuffer()
3042            compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
3043            compile_wrapper.splice(src_code, strip=True)
3044            current_device = V.graph.scheduler.get_current_device_or_throw()
3045            compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
3046
3047            metadata_comment = f"# kernel path: {kernel_path}"
3048            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
3049            metadata_comment += "\n" + origins + "\n" + detailed_origins
3050            wrapper.define_kernel(
3051                kernel_name, compile_wrapper.getvalue(), metadata_comment
3052            )
3053
3054            # log kernel metadata for offline analysis.
3055            # E.g. one can find all unaligned inner reductions and check,
3056            # kernel by kernel, whether padding helps with perf.
3057            if is_metric_table_enabled("kernel_metadata"):
3058                log_kernel_metadata(kernel_name, kernel_path, src_code)
3059
3060        return kernel_name
3061
3062    def benchmark_fused_nodes(self, nodes):
3063        with preserve_rng_state():
3064            src_code = self.generate_kernel_code_from_nodes(
3065                nodes, benchmark_kernel=True
3066            )
3067            mod = PyCodeCache.load(src_code)
3068
3069            def cache_file_path():
3070                assert mod.__file__ is not None
3071                return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
3072
3073            def load_cache():
3074                path = cache_file_path()
3075                if os.path.exists(path):
3076                    with open(path) as fd:
3077                        return float(fd.read())
3078                return None
3079
3080            def store_cache():
3081                path = cache_file_path()
3082                with open(path, "w") as fd:
3083                    fd.write(str(ms))
3084
3085            log.debug(
3086                "kernel src code for %s written to: %s",
3087                {n.get_name() for n in nodes},
3088                mod.__file__,
3089            )
3090            ms = load_cache()
3091            if ms is not None:
3092                return ms, mod.__file__
3093
3094            args = mod.get_args()
3095            call = mod.call
3096            wrapped_jit_function = mod.triton_
3097
3098            # call once to trigger the compilation
3099            try:
3100                call(wrapped_jit_function.clone_args(*args)[0])
3101            except Exception as e:
3102                log.debug(
3103                    "Exception (%s) in compiling fused nodes %s",
3104                    e,
3105                    {n.get_name() for n in nodes},
3106                )
3107                ms = float("inf")
3108                store_cache()
3109                return ms, mod.__file__
3110
3111            launchers = wrapped_jit_function.launchers
3112            assert len(launchers) == 1
3113            if launchers[0].n_spills > 0:
3114                # skip benchmarking the kernel if there are register spills
3115                ms = float("inf")
3116            else:
3117                # We have to clone the in-place updated arguments so that earlier calls
3118                # don't generate out-of-range indices for later calls.
3119                ms = benchmarker.benchmark_gpu(
3120                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
3121                )
3122
3123                # The overhead of cloning args biases the benchmark toward fusing the kernel
3124                # in the case of a mutating/in-placeable second fusion.
3125                # TODO - this would be better as a hook in triton do_bench that resets
3126                # the input values between benchmark runs.
3127                ms = ms - benchmarker.benchmark_gpu(
3128                    lambda: wrapped_jit_function.clone_args(*args)
3129                )
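                # In other words (added commentary, not part of the original
                # source), the reported time is approximately
                #   (clone args + run kernel) - (clone args) ~= kernel time alone
                # so the cloning overhead does not tilt fusion decisions.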
3130
3131            log.debug(
3132                "The fused kernel for %s took %.3f ms to run",
3133                {n.get_name() for n in nodes},
3134                ms,
3135            )
3136            store_cache()
3137            return ms, mod.__file__
3138
3139    def benchmark_combo_kernel(self, node_list):
3140        def cache_file_path():
3141            assert mod.__file__ is not None
3142            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
3143
3144        def load_cache():
3145            path = cache_file_path()
3146            if os.path.exists(path):
3147                with open(path) as fd:
3148                    return tuple(float(e) for e in fd.read().split())
3149            return (None, None)
3150
3151        def store_cache():
3152            path = cache_file_path()
3153            with open(path, "w") as fd:
3154                fd.write(str(ms) + " " + str(ms_clone))
3155
3156        total_ms, file_list = 0, []
3157        total_clone_ms = 0
3158        removed_buffers_orig = V.graph.removed_buffers
3159        V.graph.removed_buffers = OrderedSet(removed_buffers_orig)
3160        inplaced_to_remove_orig = V.graph.inplaced_to_remove
3161        V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig)
3162        enable_autotune = config.combo_kernels_autotune > 0
3163        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0
3164        kernel_code_list = self.generate_combo_kernel_code(
3165            subkernel_nodes=node_list,
3166            custom_part_algorithm=True,
3167            enable_autotune=enable_autotune,
3168            mixed_sizes=mixed_sizes,
3169            only_gen_src_code=True,
3170        )
3171
3172        for src_code, _, node_group in kernel_code_list:
3173            fused_node_lists = [node.get_nodes() for node in node_group]
3174            names = [n.get_name() for nodes in fused_node_lists for n in nodes]
3175
3176            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
3177            mod = PyCodeCache.load(src_code)
3178
3179            log.debug(
3180                "kernel src code for %s written to: %s",
3181                names,
3182                mod.__file__,
3183            )
3184            ms, ms_clone = load_cache()
3185            if ms is not None:
3186                total_ms += ms
3187                total_clone_ms += ms_clone
3188                file_list.append(mod.__file__)
3189                continue
3190
3191            args = mod.get_args()
3192            call = mod.call
3193            wrapped_jit_function = mod.triton_
3194
3195            # call once to trigger the compilation
3196            call(wrapped_jit_function.clone_args(*args)[0])
3197
3198            launchers = wrapped_jit_function.launchers
3199            assert len(launchers) == 1
3200            if launchers[0].n_spills > 0:
3201                # skip benchmarking the kernel if there are register spills
3202                ms = ms_clone = float("inf")
3203            else:
3204                # We have to clone the in-place updated arguments so that earlier calls
3205                # don't generate out-of-range indices for later calls.
3206                ms = benchmarker.benchmark_gpu(
3207                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
3208                )
3209                ms_clone = benchmarker.benchmark_gpu(
3210                    lambda: wrapped_jit_function.clone_args(*args)[0]
3211                )
3212
3213            log.debug(
3214                "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs",
3215                {n.get_name() for n in node_group},
3216                ms,
3217                ms_clone,
3218            )
3219            store_cache()
3220            total_ms += ms
3221            total_clone_ms += ms_clone
3222            file_list.append(mod.__file__)
3223        V.graph.removed_buffers = removed_buffers_orig
3224        V.graph.inplaced_to_remove = inplaced_to_remove_orig
3225        return total_ms, total_clone_ms, file_list
3226