xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/triton_heuristics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import builtins
5import copy
6import functools
7import hashlib
8import inspect
9import logging
10import math
11import operator
12import os
13import os.path
14import re
15import sys
16import threading
17import time
18from typing import Any, Dict, List, Optional, Set, Tuple
19
20import torch
21
22from .autotune_cache import AutotuneCache
23from .benchmarking import benchmarker
24from .coordinate_descent_tuner import CoordescTuner
25from .hints import (
26    _NUM_THREADS_PER_WARP,
27    AutotuneHint,
28    DeviceProperties,
29    HeuristicType,
30    ReductionHint,
31    TileHint,
32    TRITON_MAX_BLOCK,
33)
34from .runtime_utils import (
35    cache_dir,
36    ceildiv,
37    conditional_product,
38    create_bandwidth_info_str,
39    dynamo_timed,
40    get_first_attr,
41    get_max_y_grid,
42    get_num_bytes,
43    next_power_of_2,
44    triton_config_to_hashable,
45    validate_triton_config,
46)
47
48
49try:
50    import triton
51except ImportError:
52    triton = None
53
54if triton is not None:
55    from triton import Config
56    from triton.compiler import CompiledKernel
57    from triton.runtime.autotuner import OutOfResources
58    from triton.runtime.jit import KernelInterface
59
60    try:
61        from triton.compiler.compiler import ASTSource
62    except ImportError:
63        ASTSource = None
64
65    try:
66        from triton.backends.compiler import GPUTarget
67    except ImportError:
68        GPUTarget = None
69else:
70    Config = object
71    KernelInterface = object
72    OutOfResources = object
73    ASTSource = None
74    GPUTarget = None
75
76try:
77    autograd_profiler = torch.autograd.profiler
78except AttributeError:  # Compile workers only have a mock version of torch
79
80    class autograd_profiler:  # type: ignore[no-redef]
81        _is_profiler_enabled = False
82
83
84log = logging.getLogger(__name__)
85
86
87def autotune_hints_to_configs(
88    hints: Set[AutotuneHint], size_hints, block_size: int
89) -> List[Config]:
90    """
91    AutotuneHints can be attached to the metadata of triton kernels for providing
92    suggestions about what to try for autotuning. One reason to do this is if there are
93    some configs that are only useful in specific scenarios, in which case we can avoid
94    wasting compile time on autotuning unless we know we are in one of those scenarios.
95
96    Based on those hints, this function will generate a list of additional autotuning
97    configs to try.
98    """
99    xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
100    configs = []
101
102    for hint in hints:
103        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
104            if len(size_hints) == 1:
105                xyz_options = ((block_size // 4, None, None),)
106            elif len(size_hints) == 2:
107                xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
108            elif len(size_hints) == 3:
109                xyz_options = (
110                    (block_size // 4, 1, 1),
111                    (1, block_size // 4, 1),
112                    (1, 1, block_size // 4),
113                )
114            for xyz in xyz_options:
115                configs.append(
116                    triton_config(
117                        size_hints,
118                        *xyz,
119                        num_elements_per_warp=32,
120                    )
121                )
122
123    return configs
124
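# Illustrative example of the expansion above (values follow the rules in the
# function): with hints={AutotuneHint.ELEMENTS_PER_WARP_32}, size_hints=[4096] and
# block_size=512, the single 1D option is (512 // 4, None, None), so the extra
# config is roughly equivalent to triton_config([4096], 128, num_elements_per_warp=32).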
125
126def disable_pointwise_autotuning(inductor_meta):
127    # Autotuning can give different benchmarking results from run to run, and
128    # therefore we disable autotuning when deterministic algorithms are enabled.
129    if inductor_meta.get("are_deterministic_algorithms_enabled"):
130        return True
131    return not inductor_meta.get("autotune_pointwise", True)
132
133
134def _dump_launch_params(args, kwargs, launcher, kernel_name):
135    call_args = []
136    call_kwargs = {}
137    for arg in args:
138        if isinstance(arg, (int, bool)):
139            call_args.append(str(arg))
140        else:
141            call_args.append("T")
142    for k, v in kwargs.items():
143        if isinstance(v, (int, bool)):
144            call_kwargs[k] = v
145        else:
146            call_kwargs[k] = "T"
147    for k, v in launcher.config.kwargs.items():
148        call_kwargs[k] = v
149    call_kwargs["num_warps"] = launcher.config.num_warps
150    call_kwargs["num_stages"] = launcher.config.num_stages
151    args_str = ""
152    args_str += ", ".join(call_args)
153    for k, v in call_kwargs.items():
154        args_str += f", {k}={v}"
155
156    abs_path = os.path.abspath(sys.argv[0])
157    with open(f"{abs_path}.launch_params", "a") as f:
158        f.write(f"{kernel_name} | {args_str}\n")
159
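# Example of a line appended to "<script>.launch_params" (kernel name and values
# are illustrative): tensor arguments are masked as "T", scalars are kept, and the
# config kwargs plus num_warps/num_stages are appended as k=v pairs, e.g.
#
#   triton_poi_fused_add_0 | T, T, 1024, XBLOCK=128, num_warps=4, num_stages=1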
160
161class CachingAutotuner(KernelInterface):
162    """
163    Simplified version of Triton autotuner that has no invalidation
164    key and caches the best config to disk to improve cold start times.
165    Unlike the main triton Autotuner, this version can precompile all
166    configs, and does not rely on the Triton JIT.
167    """
168
169    def __init__(
170        self,
171        fn,
172        triton_meta,  # passed directly to triton
173        configs,
174        save_cache_hook,
175        mutated_arg_names: List[str],  # see [Note: clone mutated buffers]
176        heuristic_type,
177        size_hints=None,
178        inductor_meta=None,  # metadata not relevant to triton
179        custom_kernel=False,  # whether the kernel is inductor-generated or custom
180        filename: Optional[str] = None,
181    ):
182        super().__init__()
183
184        assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
185        # makes sure there are no pre-hooks on any of the triton configs
186        for cfg in configs:
187            validate_triton_config(cfg)
188
189        self.fn = fn
190        self.device_props: DeviceProperties = triton_meta["device"]
191        self.triton_meta = {
192            **triton_meta,
193            "device": self.device_props.index,
194            "device_type": self.device_props.type,
195        }
196        self.inductor_meta = {} if inductor_meta is None else inductor_meta
197        self.save_cache_hook = save_cache_hook
198        self.mutated_arg_names = mutated_arg_names
199        self.configs = configs
200        self.heuristic_type = heuristic_type
201        self.custom_kernel = custom_kernel
202        self.cuda_kernel_saved = False
203        if log.isEnabledFor(logging.DEBUG):
204            log.debug(
205                "CachingAutotuner gets %d configs for %s",
206                len(self.configs),
207                self.fn.__name__,
208            )
209            for c in self.configs:
210                log.debug(c)
211
212        self.launchers = []  # type: ignore[var-annotated]
213        self.lock = threading.Lock()
214        if os.getenv("TRITON_CACHE_DIR") is None:
215            os.environ["TRITON_CACHE_DIR"] = os.path.join(
216                cache_dir(),
217                "triton",
218                str(self.triton_meta.get("device", 0)),
219            )
220        log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
221
222        self.size_hints = size_hints
223        self.coordesc_tuner = CoordescTuner(
224            is_mm=False,
225            name=self.fn.__name__,
226            size_hints=size_hints,
227            inductor_meta=self.inductor_meta,
228        )
229        self.filename = filename
230
231        self.precompile_time_taken_ns = 0
232        self.autotune_time_taken_ns = 0
233
234    def precompile(self, warm_cache_only=False):
235        with self.lock:
236            if self.launchers:
237                return
238            self.launchers = []
239            compiled_binaries = []
240            if not self.configs:
241                raise RuntimeError("No triton configs are available")
242            for c in self.configs:
243                try:
244                    compiled_binary, launcher = self._precompile_config(
245                        c, warm_cache_only
246                    )
247                except OutOfResources as e:
248                    if len(self.configs) == 1:
249                        # There are no valid Triton configs
250                        raise e
251                    # Skip the config if we run out of resources
252                    continue
253                self.launchers.append(launcher)
254                compiled_binaries.append(compiled_binary)
255
256            if len(self.launchers) == 0:
257                raise RuntimeError(
258                    "No valid triton configs. Report a fatal compilation error"
259                )
260
261            seen_configs = set(self.configs)
262
263            device_prop = self.device_props
264            if (
265                self.inductor_meta.get("dynamic_scale_rblock", True)
266                and self.heuristic_type == HeuristicType.REDUCTION
267                and self.size_hints is not None
268                # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
269                and device_prop.type == "cuda"
270                and device_prop.major
271                and device_prop.major >= 8
272            ):
273                assert device_prop.regs_per_multiprocessor
274                assert device_prop.max_threads_per_multi_processor
275                assert device_prop.multi_processor_count
276                for triton_config, compiled_binary in zip(
277                    self.configs, compiled_binaries
278                ):
279                    assert len(self.size_hints) == 2
280                    xblock = triton_config.kwargs.get("XBLOCK", 1)
281                    rblock = triton_config.kwargs["RBLOCK"]
282                    total_block = (self.size_hints[0] + xblock - 1) // xblock
283                    nreg = getattr(compiled_binary, "n_regs", None)
284                    if nreg is None:
285                        continue
286
287                    # make sure rblock is not too small
288                    if rblock <= 64:
289                        continue
290
291                    # each SM of A100 has 65536 32-bit registers. To maximize
292                    # the theoretical occupancy, we need to run 2048 threads on each
293                    # SM. So each thread should use no more than 65536 / 2048
294                    # = 32 registers. In cases where occupancy matters, and each
295                    # thread uses too many registers, reduce RBLOCK to reduce
296                    # the register usage.
297                    # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
298                    # from PLBartForCausalLM, latency improves from
299                    # 7.795ms to 4.883ms.
300                    #
301                    if (
302                        nreg
303                        <= device_prop.regs_per_multiprocessor
304                        // device_prop.max_threads_per_multi_processor
305                    ):
306                        continue
307
308                    nreg_per_warp = nreg * 32
309                    nreg_per_block = nreg_per_warp * triton_config.num_warps
310
311                    # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processor / (32 * num_warps)'
312                    # The formula below is a tighter upper bound since we have the assumption that
313                    #   nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor
314                    # due to the if condition above and:
315                    #   regs_per_multiprocessor / nreg_per_block
316                    #   = regs_per_multiprocessor / (nreg * 32 * num_warps)
317                    #   < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
318                    #   = max_threads_per_multi_processor / (32 * num_warps)
319                    # Using a tighter upper bound can reveal more optimization opportunities.
320                    max_blocks_per_sm = max(
321                        device_prop.regs_per_multiprocessor // nreg_per_block, 1
322                    )
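                    # Worked example with A100-class numbers (illustrative): 65536
                    # regs/SM and 2048 max threads/SM give the 32 regs/thread
                    # threshold above; a config with nreg = 40 and num_warps = 8 has
                    # nreg_per_block = 40 * 32 * 8 = 10240, so
                    # max_blocks_per_sm = 65536 // 10240 = 6.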
323
324                    if (
325                        total_block
326                        <= max_blocks_per_sm * device_prop.multi_processor_count
327                    ):
328                        # no need to improve occupancy
329                        continue
330                    new_config = copy.deepcopy(triton_config)
331                    new_config.kwargs["RBLOCK"] = rblock // 2
332                    if new_config in seen_configs:
333                        continue
334                    seen_configs.add(new_config)
335                    log.debug(
336                        "Dynamically scale down RBLOCK from TritonConfig(%s) and get a new TritonConfig(%s)",
337                        triton_config,
338                        new_config,
339                    )
340                    self.launchers.append(
341                        self._precompile_config(new_config, warm_cache_only)[1]
342                    )
343            self.configs = None
344
345    def get_device_interface(self):
346        # this code cannot run in compile workers, because it imports from torch
347        from torch._dynamo.device_interface import get_interface_for_device
348
349        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
350
351    def _precompile_config(self, cfg: Config, warm_cache_only: bool):
352        """Ahead of time compile a given autotuner config."""
353        compile_meta = copy.deepcopy(self.triton_meta)
354        for k, v in cfg.kwargs.items():
355            if self.device_props.type == "hip":
356                if k == "matrix_instr_nonkdim":
357                    compile_meta["matrix_instr_nonkdim"] = v
358                    continue
359                if k == "waves_per_eu":
360                    compile_meta["waves_per_eu"] = v
361                    continue
362            compile_meta["constants"][self.fn.arg_names.index(k)] = v
363        compile_meta["num_warps"] = cfg.num_warps
364        compile_meta["num_stages"] = cfg.num_stages
365        compile_meta["debug"] = self.inductor_meta.get(
366            "assert_indirect_indexing", True
367        ) and not self.inductor_meta.get("is_hip", False)
368
369        # device type will be "hip" rather than "cuda" here
370        compile_meta["device_type"] = self.device_props.type
371        compile_meta["cc"] = self.device_props.cc
372
373        if ASTSource:
374            compile_args = (
375                ASTSource(
376                    self.fn,
377                    compile_meta["signature"],
378                    compile_meta["constants"],
379                    compile_meta["configs"][0],
380                ),
381            )
382
383            cc_str = str(compile_meta["cc"])
384            if "gfx10" in cc_str or "gfx11" in cc_str:
385                rocm_warp_size = 32
386            else:
387                rocm_warp_size = 64
388
389            if GPUTarget:
390                target = GPUTarget(
391                    compile_meta["device_type"],
392                    compile_meta["cc"],
393                    rocm_warp_size if torch.version.hip else 32,
394                )
395            else:
396                target = (
397                    (compile_meta["device_type"], compile_meta["cc"])
398                    if not torch.version.hip
399                    else [
400                        compile_meta["device_type"],
401                        compile_meta["cc"],
402                        rocm_warp_size,
403                    ]
404                )
405
406            options = {
407                "num_warps": compile_meta["num_warps"],
408                "num_stages": compile_meta["num_stages"],
409                "debug": compile_meta["debug"],
410            }
411            if self.device_props.type == "hip":
412                if "waves_per_eu" in compile_meta:
413                    options["waves_per_eu"] = compile_meta["waves_per_eu"]
414                if "matrix_instr_nonkdim" in compile_meta:
415                    options["matrix_instr_nonkdim"] = compile_meta[
416                        "matrix_instr_nonkdim"
417                    ]
418            compile_kwargs = {
419                "target": target,
420                "options": options,
421            }
422        else:
423            compile_args = (self.fn,)
424            compile_kwargs = compile_meta
425
426        if warm_cache_only:
427            return (
428                triton.compile(*compile_args, **compile_kwargs),
429                None,
430            )
431
432        # importing from torch is safe now; the warm_cache_only path used by compile workers has already returned above
433        from torch._dynamo.device_interface import DeviceGuard
434
435        device_interface = self.get_device_interface()
436
437        # load binary to the correct device
438        with DeviceGuard(device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
439            # need to initialize context
440            device_interface.synchronize(device_interface.current_device())
441
442            try:
443                binary = triton.compile(*compile_args, **compile_kwargs)
444            except Exception:
445                log.exception(
446                    "Triton compilation failed: %s\n%s\nmetadata: %s",
447                    self.inductor_meta.get("kernel_name", "triton_"),
448                    self.fn.src,
449                    compile_meta,
450                )
451                raise
452            binary._init_handles()
453
454        call_args = [
455            arg
456            for i, arg in enumerate(self.fn.arg_names)
457            if i not in self.fn.constexprs
458        ]
459        def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs]
460
461        binary_shared = (
462            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
463        )
464
465        scope = {
466            "grid_meta": cfg.kwargs,
467            "bin": binary,
468            "launch_enter_hook": CompiledKernel.launch_enter_hook,
469            "launch_exit_hook": CompiledKernel.launch_exit_hook,
470            "metadata": binary.packed_metadata
471            if hasattr(binary, "packed_metadata")
472            else binary.metadata,
473            "shared": binary_shared,
474        }
475
476        scope["num_warps"] = (
477            binary.num_warps
478            if hasattr(binary, "num_warps")
479            else binary.metadata.num_warps
480        )
481
482        scope["cta_args"] = (
483            (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
484            if hasattr(binary, "num_ctas")
485            else (
486                (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
487                if hasattr(binary, "metadata")
488                else ()
489            )
490        )
491
492        scope["function"] = get_first_attr(binary, "function", "cu_function")
493
494        def get_launch_args_without_kernel_launch_metadata(
495            grid,
496            grid_0,
497            grid_1,
498            grid_2,
499            stream,
500            function,
501            metadata,
502            bin,
503            launch_enter_hook,
504            launch_exit_hook,
505            num_warps,
506            shared,
507            cta_args,
508            args,
509        ):
510            """
511            Construct launch args before CompiledKernel.launch_metadata is added.
512            """
513            return (
514                grid_0,
515                grid_1,
516                grid_2,
517                num_warps,
518                *cta_args,
519                shared,
520                stream,
521                function,
522                launch_enter_hook,
523                launch_exit_hook,
524                metadata,
525            )
526
527        # Getting the kernel launch args is extremely perf-sensitive.  Evaluating
528        # `bin.launch_metadata` is relatively expensive, and returns None unless a
529        # `launch_enter_hook` is installed.  So if we don't have that hook installed,
530        # we want to burn None in to the launch args with zero overhead.
531        # See https://github.com/pytorch/pytorch/issues/123597
532        if binary.launch_enter_hook:
533
534            def get_launch_args_with_kernel_launch_metadata(
535                grid,
536                grid_0,
537                grid_1,
538                grid_2,
539                stream,
540                function,
541                metadata,
542                bin,
543                launch_enter_hook,
544                launch_exit_hook,
545                num_warps,
546                shared,
547                cta_args,
548                args,
549            ):
550                """
551                Construct launch args after CompiledKernel.launch_metadata is added
552                by https://github.com/openai/triton/pull/3492 .
553                """
554                return (
555                    grid_0,
556                    grid_1,
557                    grid_2,
558                    stream,
559                    function,
560                    metadata,
561                    bin.launch_metadata(grid, stream, *args),
562                    launch_enter_hook,
563                    launch_exit_hook,
564                )
565
566        else:
567
568            def get_launch_args_with_kernel_launch_metadata(
569                grid,
570                grid_0,
571                grid_1,
572                grid_2,
573                stream,
574                function,
575                metadata,
576                bin,
577                launch_enter_hook,
578                launch_exit_hook,
579                num_warps,
580                shared,
581                cta_args,
582                args,
583            ):
584                """
585                Construct launch args after CompiledKernel.launch_metadata is added
586                by https://github.com/openai/triton/pull/3492 .
587                """
588                return (
589                    grid_0,
590                    grid_1,
591                    grid_2,
592                    stream,
593                    function,
594                    metadata,
595                    None,
596                    launch_enter_hook,
597                    launch_exit_hook,
598                )
599
600        scope["get_launch_args"] = (
601            get_launch_args_with_kernel_launch_metadata
602            if hasattr(binary, "launch_metadata")
603            else get_launch_args_without_kernel_launch_metadata
604        )
605
606        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
607
608        exec(
609            f"""
610            def launcher({', '.join(def_args)}, grid, stream):
611                if callable(grid):
612                    grid_0, grid_1, grid_2 = grid(grid_meta)
613                else:
614                    grid_0, grid_1, grid_2 = grid
615
616                args = {', '.join(call_args)},
617                launch_args = get_launch_args(
618                    grid, grid_0, grid_1, grid_2, stream, function,
619                    metadata, bin, launch_enter_hook, launch_exit_hook,
620                    num_warps, shared, cta_args, args
621                )
622                runner(*launch_args, *args)
623                return bin
624            """.lstrip(),
625            scope,
626        )
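        # For a hypothetical kernel def kernel(in_ptr0, out_ptr0, xnumel,
        # XBLOCK: tl.constexpr) whose config supplies XBLOCK, def_args drops the
        # config kwarg and call_args drops the constexprs, so the exec above
        # produces roughly:
        #
        #   def launcher(in_ptr0, out_ptr0, xnumel, grid, stream):
        #       ...
        #       args = in_ptr0, out_ptr0, xnumel,
        #       launch_args = get_launch_args(...)
        #       runner(*launch_args, *args)
        #       return bin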
627
628        launcher = scope["launcher"]
629        launcher.config = cfg
630        launcher.n_regs = getattr(binary, "n_regs", None)
631        launcher.n_spills = getattr(binary, "n_spills", None)
632        launcher.shared = binary_shared
633        launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
634        # store this global variable to avoid the high overhead of reading it when calling run
635        if launcher.store_cubin:
636            launcher.fn = self.fn
637            launcher.bin = binary
638
639        return binary, launcher
640
641    def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
642        """Measure the performance of a given launcher"""
643        # we don't skip configs with spilled registers when auto-tuning custom
644        # (user-written) Triton kernels, as (i) we don't have any knowledge or
645        # control over the kernel code; (ii) there is empirical evidence that
646        # for some (complicated) custom Triton kernels, a register-spilling
647        # config may yield the best latency.
648        if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
649            "spill_threshold", 16
650        ):
651            log.debug(
652                "Skip config %s because of register spilling: %d",
653                launcher.config,
654                launcher.n_spills,
655            )
656            return float("inf")
657
658        device_interface = self.get_device_interface()
659        stream = device_interface.get_raw_stream(device_interface.current_device())
660
661        def kernel_call():
662            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
663            launcher(
664                *cloned_args,
665                **cloned_kwargs,
666                grid=grid,
667                stream=stream,
668            )
669
670        if with_profiler:
671            from torch._inductor.utils import do_bench_using_profiling
672
673            return do_bench_using_profiling(kernel_call, warmup=10, rep=40)
674
675        return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True)
676
677    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
678        from ..compile_fx import clone_preserve_strides
679
680        # [Note: clone mutated buffers]
681        # clone inplace buffers to avoid autotune contaminating them if
682        # the kernel does in-place stores. avoid cloning other buffers because
683        # it leads to increase memory use
684        cloned_args = []
685        for i, arg in enumerate(args):
686            if self.fn.arg_names[i] in self.mutated_arg_names:
687                assert isinstance(arg, torch.Tensor)
688                cloned_args.append(clone_preserve_strides(arg))
689            else:
690                cloned_args.append(arg)
691
692        cloned_kwargs: Dict[str, Any] = {}
693        for name, arg in kwargs.items():
694            if name in self.mutated_arg_names:
695                assert isinstance(arg, torch.Tensor)
696                cloned_kwargs[name] = clone_preserve_strides(arg)
697            else:
698                cloned_kwargs[name] = arg
699
700        return cloned_args, cloned_kwargs
701
702    def benchmark_all_configs(self, *args, **kwargs):
703        with dynamo_timed("CachingAutotuner.benchmark_all_configs"):
704            timings = {
705                launcher: self.bench(launcher, *args, **kwargs)
706                for launcher in self.launchers
707            }
708
709            for k, v in timings.items():
710                self.coordesc_tuner.cache_benchmark_result(k.config, v)
711
712            if log.isEnabledFor(logging.DEBUG):
713                log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
714                for k, v in timings.items():
715                    log.debug(
716                        "%s: %f, nreg %d, nspill %d, #shared-mem %s",
717                        k.config,
718                        v,
719                        k.n_regs,
720                        k.n_spills,
721                        k.shared,
722                    )
723
724            return timings
725
726    def autotune_to_one_config(self, *args, **kwargs):
727        """Do the actual autotuning"""
728        start_time = time.time_ns()
729        timings = self.benchmark_all_configs(*args, **kwargs)
730        benchmark_time_taken_ns = time.time_ns() - start_time
731        self.launchers = [builtins.min(timings, key=timings.get)]
732        self.autotune_time_taken_ns = (
733            self.precompile_time_taken_ns + benchmark_time_taken_ns
734        )
735        if self.save_cache_hook:
736            self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)
737
738    def save_gpu_kernel(self, grid, stream, launcher):
739        if callable(grid):
740            grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
741        else:
742            grid_x, grid_y, grid_z = grid
743
744        key = self.inductor_meta.get("kernel_name", None)  # unique kernel name
745        assert key is not None, "kernel_name cannot be None"
746        params = {
747            "mangled_name": launcher.bin.metadata.name
748            if hasattr(launcher.bin.metadata, "name")
749            else launcher.bin.metadata["name"],
750            "grid_x": grid_x,
751            "grid_y": grid_y,
752            "grid_z": grid_z,
753            "x_block": launcher.config.kwargs.get("XBLOCK", 1),
754            "y_block": launcher.config.kwargs.get("YBLOCK", None),
755            "z_block": launcher.config.kwargs.get("ZBLOCK", None),
756            "num_warps": launcher.bin.num_warps
757            if hasattr(launcher.bin, "num_warps")
758            else launcher.bin.metadata.num_warps,
759            "shared_mem": launcher.bin.shared
760            if hasattr(launcher.bin, "shared")
761            else launcher.bin.metadata.shared,
762            "stream": stream,
763            # User defined triton kernels will have arbitrary kwarg names
764            "meta": launcher.config.kwargs,
765        }
766        from torch._inductor.codecache import CudaKernelParamCache
767
768        bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
769        binary = launcher.bin.asm[bin_type]
770        CudaKernelParamCache.set(key, params, binary, bin_type)
771
772        self.cuda_kernel_saved = True
773
774    def coordinate_descent_tuning(self, launcher, *args, **kwargs):
775        """
776        Coordinate descent tuning can be run with or without max-autotune.
777
778        The only difference between these two is the starting config for coordinate_descent tuning.
779        E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4
780        and max-autotune figure out C3 is the best.
781
782        Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1;
783        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
784        """
785        if (
786            self.heuristic_type == HeuristicType.TEMPLATE
787            or self.heuristic_type == HeuristicType.USER_AUTOTUNE
788        ):
789            # skip triton template
790            return launcher
791
792        config2launcher = {launcher.config: launcher}
793
794        def benchmark_one_config(config):
795            with self.lock:
796                _, launcher = self._precompile_config(config, False)
797            config2launcher[config] = launcher
798
799            out = self.bench(launcher, *args, **kwargs)
800            log.debug(
801                "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
802                launcher.config,
803                out,
804                launcher.n_regs,
805                launcher.n_spills,
806                launcher.shared,
807            )
808            return out
809
810        assert not (
811            self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
812            and "RBLOCK" in launcher.config.kwargs
813        ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
814        start_time = time.time_ns()
815        best_config = self.coordesc_tuner.autotune(
816            benchmark_one_config, launcher.config, None
817        )
818        coordesc_time_taken_ns = time.time_ns() - start_time
819        best_config.found_by_coordesc = True
820
821        if self.save_cache_hook:
822            self.save_cache_hook(
823                best_config,
824                self.autotune_time_taken_ns + coordesc_time_taken_ns,
825                found_by_coordesc=True,
826            )
827        return config2launcher.get(best_config)
828
829    def run(self, *args, grid, stream, **kwargs):
830        if len(self.launchers) != 1:
831            if len(self.launchers) == 0:
832                start_time = time.time_ns()
833                self.precompile()
834                self.precompile_time_taken_ns = time.time_ns() - start_time
835            if len(self.launchers) > 1:
836                self.autotune_to_one_config(*args, grid=grid, **kwargs)
837
838        if not getattr(
839            self.launchers[0].config, "found_by_coordesc", False
840        ) and self.inductor_meta.get("coordinate_descent_tuning", False):
841            self.launchers = [
842                self.coordinate_descent_tuning(
843                    self.launchers[0], *args, grid=grid, **kwargs
844                )
845            ]
846
847        (launcher,) = self.launchers
848        if launcher.store_cubin:
849            self.save_gpu_kernel(grid, stream, launcher)
850
851        if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1":
852            _dump_launch_params(args, kwargs, launcher, self.fn.__name__)
853
854        # checking the flag directly is faster than entering and exiting a context
855        # manager, even if the context manager is a nullcontext.
856        if autograd_profiler._is_profiler_enabled:
857            # grid can be a tuple of ints or a grid callable carrying a grid_fn_str.
858            if isinstance(grid, tuple):
859                grid_info = str(grid)
860            else:
861                grid_info = getattr(grid, "grid_fn_str", "")
862            with torch._C._profiler._RecordFunctionFast(
863                self.inductor_meta.get("kernel_name", "triton kernel"),
864                args,
865                {
866                    "kernel_file": "" if self.filename is None else self.filename,
867                    "kernel_backend": "triton",
868                    "grid": grid_info,
869                    "stream": stream,
870                },
871            ):
872                return launcher(
873                    *args,
874                    **kwargs,
875                    grid=grid,
876                    stream=stream,
877                )
878        else:
879            return launcher(
880                *args,
881                **kwargs,
882                grid=grid,
883                stream=stream,
884            )
885
886
887def _find_names(obj):
888    import gc
889    import inspect
890
891    frame = inspect.currentframe()
892    while frame is not None:
893        frame.f_locals
894        frame = frame.f_back
895    obj_names = []
896    for referrer in gc.get_referrers(obj):
897        if isinstance(referrer, dict):
898            for k, v in referrer.items():
899                if v is obj:
900                    obj_names.append(k)
901    return obj_names
902
903
904collected_calls: List[Any] = []
905
906
907def start_graph():
908    collected_calls.clear()
909
910
911def end_graph(output_file):
912    if len(collected_calls) == 0:
913        return
914    overall_time = sum(call[0] for call in collected_calls)
915    overall_gb = sum(call[1] for call in collected_calls)
916    cur_file = inspect.stack()[1].filename
917    summary_str = (
918        f"SUMMARY ({cur_file})\n"
919        f"{overall_time:.2f}ms   \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
920    )
921    print(summary_str)
922    print()
923    if output_file is not None:
924        # sort perf numbers in descending order, i.e. placing the
925        # most runtime-heavy kernels at the top of the list
926        sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
927        try:
928            with open(output_file, "a") as file:
929                log.debug("Save profile bandwidth results to %s", output_file)
930                file.write("====================\n")
931                file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
932                for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
933                    # also display the runtime percentage for each kernel
934                    percentage = f"{ms/overall_time*100:.2f}%"
935                    suffix = f" \t {percentage} \t {kernel_name}"
936                    bw_info_str = create_bandwidth_info_str(
937                        ms,
938                        num_gb,
939                        gb_per_s,
940                        suffix=suffix,
941                        color=False,
942                    )
943                    file.write(bw_info_str + "\n")
944                file.write(f"{summary_str}\n\n")
945        except Exception as e:
946            log.warning(
947                "failed to write profile bandwidth result into %s: %s",
948                output_file,
949                e,
950            )
951
952
953class DebugAutotuner(CachingAutotuner):
954    def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs):
955        self.regex_filter = regex_filter
956        self.with_profiler = with_profiler
957        super().__init__(*args, **kwargs)
958        self.cached = None
959
960    def run(self, *args, grid, stream):
961        possible_names = _find_names(self)
962        kernel_name = f"{max(possible_names, key=len)}"
963        if not re.match(self.regex_filter, kernel_name):
964            return
965        super().run(*args, grid=grid, stream=stream)
966        (launcher,) = self.launchers
967
968        if self.cached is None:
969            ms = self.bench(
970                launcher, *args, grid=grid, with_profiler=self.with_profiler
971            )
972            num_in_out_ptrs = len(
973                [
974                    arg_name
975                    for arg_name in self.fn.arg_names
976                    if arg_name.startswith("in_out_ptr")
977                ]
978            )
979            num_gb = self.inductor_meta.get("kernel_num_gb", None)
980            if num_gb is None:
981                num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
982            gb_per_s = num_gb / (ms / 1e3)
983            self.cached = ms, num_gb, gb_per_s, kernel_name
984            collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
985            print(
986                create_bandwidth_info_str(
987                    ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
988                )
989            )
990
991
992def hash_configs(configs: List[Config]):
993    """
994    Hash used to check for changes in configurations
995    """
996    hasher = hashlib.sha256()
997    for cfg in configs:
998        hasher.update(
999            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
1000        )
1001    return hasher.hexdigest()
1002
1003
1004def cached_autotune(
1005    size_hints: Optional[List[int]],
1006    configs: List[Config],
1007    triton_meta,
1008    heuristic_type,
1009    filename=None,
1010    inductor_meta=None,
1011    custom_kernel=False,
1012):
1013    """
1014    A copy of triton.autotune that calls our subclass.  Our subclass
1015    has additional debugging, error handling, and on-disk caching.
1016    """
1017    configs = unique_configs(configs)
1018    assert len(configs) == 1 or filename
1019    inductor_meta = {} if inductor_meta is None else inductor_meta
1020
1021    disabled = inductor_meta.get("force_disable_caches", False)
1022
1023    # on disk caching logic and/or remote caching
1024    autotune_cache = None
1025    if (
1026        not disabled
1027        and filename is not None
1028        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
1029    ):
1030        configs_hash = hash_configs(configs)
1031
1032        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
1033        if autotune_cache:
1034            if best_config := autotune_cache.read_best(inductor_meta, configs):
1035                configs = [best_config]
1036
1037    else:
1038        if disabled:
1039            log.debug("autotune caching is disabled by config.force_disable_caches")
1040
1041    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
1042
1043    def decorator(fn):
1044        # Remove XBLOCK from config if it's not a function argument.
1045        # This way, coordinate descent tuning will not try to tune it.
1046        #
1047        # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
1048        import inspect
1049
1050        if "XBLOCK" not in inspect.signature(fn.fn).parameters:
1051            for tconfig in configs:
1052                if "XBLOCK" in tconfig.kwargs:
1053                    assert tconfig.kwargs["XBLOCK"] == 1
1054                    tconfig.kwargs.pop("XBLOCK")
1055
1056        if inductor_meta.get("profile_bandwidth"):
1057            return DebugAutotuner(
1058                fn,
1059                triton_meta=triton_meta,
1060                inductor_meta=inductor_meta,
1061                regex_filter=inductor_meta["profile_bandwidth_regex"],
1062                with_profiler=inductor_meta[
1063                    "profile_bandwidth_with_do_bench_using_profiling"
1064                ],
1065                configs=configs,
1066                save_cache_hook=autotune_cache and autotune_cache.save,
1067                mutated_arg_names=mutated_arg_names,
1068                heuristic_type=heuristic_type,
1069                size_hints=size_hints,
1070                custom_kernel=custom_kernel,
1071                filename=filename,
1072            )
1073        return CachingAutotuner(
1074            fn,
1075            triton_meta=triton_meta,
1076            inductor_meta=inductor_meta,
1077            configs=configs,
1078            save_cache_hook=autotune_cache and autotune_cache.save,
1079            mutated_arg_names=mutated_arg_names,
1080            heuristic_type=heuristic_type,
1081            size_hints=size_hints,
1082            custom_kernel=custom_kernel,
1083            filename=filename,
1084        )
1085
1086    return decorator
1087
1088
1089def unique_configs(configs: List[Config]):
1090    """Remove duplicate configurations"""
1091    seen = set()
1092    pruned_configs = []
1093
1094    for cfg in configs:
1095        key = triton_config_to_hashable(cfg)
1096        if key not in seen:
1097            seen.add(key)
1098            pruned_configs.append(cfg)
1099    return pruned_configs
1100
1101
1102def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
1103    for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
1104        if numel is None:
1105            continue
1106        block = cfg[f"{label}BLOCK"]
1107        if numel == 1:
1108            assert block == 1, (
1109                f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
1110                f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
1111            )
1112        max_block = TRITON_MAX_BLOCK[label]
1113        max_block_str = f'config.triton.max_block["{label}"]'
1114        assert max_block % block == 0, (
1115            f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
1116            f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
1117        )
1118
1119
1120def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
1121    # On AMD GPUs each warp has 64 lanes, double the size on NV GPUs, so we
1122    # correspondingly use half the number of warps here.
1123    if torch.version.hip:
1124        max_num_warps = (max_num_warps + 1) // 2
1125        min_num_warps = (min_num_warps + 1) // 2
1126    # persistent reduction is register intensive
1127    if register_intensive:
1128        max_num_warps = max_num_warps // 2
1129    return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))
1130
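# Illustrative values: on a CUDA build, _num_warps(3) clamps to [2, 8] and rounds up
# to next_power_of_2(3) == 4, while _num_warps(32, max_num_warps=16,
# register_intensive=True) halves the cap to 8 and returns 8.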
1131
1132def _check_max_grid_x(size_hints, x, num_warps):
1133    # Check if maxGridSize is exceeded; if so, XBLOCK must be scaled up further
1134    max_grid_x = 2147483647
1135    warp_size = (
1136        64 if torch.version.hip else 32
1137    )  # TODO: query warp size once #129663 is merged
1138    num_blocks = (size_hints[0] + x - 1) // x
1139
1140    while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints[0]:
1141        x *= 2  # Scale up XBLOCK if grid exceeds limits
1142        num_blocks = num_blocks // 2
1143        if x >= max_grid_x:
1144            raise AssertionError(
1145                "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
1146            )
1147    return x, num_blocks
1148
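# Illustrative trace of _check_max_grid_x (CUDA warp size 32): for
# size_hints=[2**33], x=1, num_warps=2, the loop keeps doubling x and halving
# num_blocks until num_blocks * num_warps * 32 fits under 2147483647, returning
# x=512 and num_blocks=2**24.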
1149
1150def triton_config(
1151    size_hints,
1152    x,
1153    y=None,
1154    z=None,
1155    num_stages=1,
1156    num_elements_per_warp=256,
1157    min_elem_per_thread=0,
1158) -> Config:
1159    """
1160    Construct a pointwise triton config with some adjustment heuristics
1161    based on size_hints. Size_hints is a tuple of numels in each tile
1162    dimension and will be rounded up to the nearest power of 2.
1163
1164    num_elements_per_warp is a suggestion for controlling how many warps
1165    the triton config should contain. e.g.: if x=16, y=8, z=4 then
1166    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
1167    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
1168    just a suggestion, and sometimes other adjustment heuristics will
1169    override the num_elements_per_warp.
1170
1171    min_elem_per_thread controls the minimum number of elements
1172    processed by each thread. It's always enforced.
1173    """
1174    # Ideally we want to read this from some device config
1175
1176    # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK
1177    size_hints = list(reversed(size_hints))
1178
1179    maxGridSize = [2147483647, 65535, 65535]
1180
1181    target = conditional_product(x, y, z)
1182    if conditional_product(*size_hints) < target:
1183        target //= 8
1184
1185    # shrink sizes to size hints
1186    x = min(x, size_hints[0])
1187    if y:
1188        y = min(y, size_hints[1])
1189    if z:
1190        z = min(z, size_hints[2])
1191
1192    # if we are below original block size, scale up where we can;
1193    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
1194    while x < min(size_hints[0], TRITON_MAX_BLOCK["X"]) and (
1195        x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
1196    ):
1197        x *= 2
1198    while (
1199        y
1200        and y < min(size_hints[1], TRITON_MAX_BLOCK["Y"])
1201        and (
1202            y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
1203        )
1204    ):
1205        y *= 2
1206    while (
1207        z
1208        and z < min(size_hints[2], TRITON_MAX_BLOCK["Z"])
1209        and (
1210            z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
1211        )
1212    ):
1213        z *= 2
1214
1215    num_warps = _num_warps(
1216        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
1217    )
1218    # We will only arrive at 2 warps if bs was too small due to numel being
1219    # too small. However, to work around some ptx bugs we still want at least
1220    # 4 warps if there are enough elements per thread.
1221    # Given that this is a rare situation, don't expect this to affect perf
1222    # in general.
1223    # see https://github.com/pytorch/pytorch/pull/97950
1224    if conditional_product(x, y, z) >= 128 and not torch.version.hip:
1225        num_warps = max(num_warps, 4)
1226    xnumel = size_hints[0]
1227    ynumel = size_hints[1] if y else None
1228    znumel = size_hints[2] if z else None
1229
1230    # Increase x to satisfy min_elem_per_thread requirements.
1231    block_size = max(
1232        conditional_product(x, y, z),
1233        min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
1234    )
1235    x *= math.ceil(block_size / conditional_product(x, y, z))
1236
1237    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
1238
1239    cfg = {"XBLOCK": x}
1240    if y:
1241        cfg["YBLOCK"] = y
1242    if z:
1243        cfg["ZBLOCK"] = z
1244    assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
1245    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
1246    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
1247
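# Illustrative trace of triton_config (values follow the rules above):
# triton_config([16, 512], 32, 32) reverses the hints to [512, 16], doubles x once to
# reach the 32 * 32 target, and yields roughly
# Config({"XBLOCK": 64, "YBLOCK": 16}, num_warps=4, num_stages=1); note that the
# first size hint (16) lands in YBLOCK.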
1248
1249def triton_config_reduction(
1250    size_hints, x, r, num_stages=1, num_warps=None, register_intensive=False
1251) -> Config:
1252    """
1253    Construct a reduction triton config with some adjustment heuristics
1254    based on size_hints. Size_hints is a tuple of numels in each tile
1255    dimension and will be rounded up to the nearest power of 2.
1256    """
1257
1258    target = conditional_product(x, r)
1259    if conditional_product(*size_hints) < target:
1260        target //= 8
1261
1262    # shrink sizes to size hints
1263    x = min(x, size_hints[0])
1264    r = min(r, size_hints[1])
1265
1266    # if we are below original block size, scale up where we can
1267    while x < size_hints[0] and conditional_product(x, r) < target:
1268        x *= 2
1269    while r < size_hints[1] and conditional_product(x, r) < target:
1270        r *= 2
1271
1272    if num_warps is None:
1273        num_warps = conditional_product(x, r) // 128
1274    num_warps = _num_warps(
1275        num_warps, max_num_warps=16, register_intensive=register_intensive
1276    )
1277
1278    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
1279
1280    while conditional_product(x, r) > target:
1281        if r == 1:
1282            break
1283        r = r // 2
1284
1285    cfg = {"XBLOCK": x, "RBLOCK": r}
1286    check_config(cfg, xnumel=size_hints[0])
1287    assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
1288    assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['R'] to {r}"
1289    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
1290
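# Illustrative trace of triton_config_reduction on a CUDA build:
# triton_config_reduction([1024, 4096], 64, 64) keeps XBLOCK=64 and RBLOCK=64
# (the 64 * 64 target is already met) and derives num_warps = 4096 // 128 = 32,
# clamped to 16, so the result is roughly
# Config({"XBLOCK": 64, "RBLOCK": 64}, num_warps=16, num_stages=1).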
1291
1292def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
1293    """
1294    Construct a tile reduction triton config with some adjustment
1295    heuristics based on size_hints. Size_hints is a tuple of numels in
1296    each tile dimension and will be rounded up to the nearest power of 2.
1297    """
1298
1299    target = conditional_product(x, y, r)
1300    if conditional_product(*size_hints) < target:
1301        target //= 8
1302
1303    # shrink sizes to size hints
1304    x = min(x, size_hints[0])
1305    y = min(y, size_hints[1])
1306    r = min(r, size_hints[2])
1307
1308    # if we are below original block size, scale up where we can
1309    while x < size_hints[0] and conditional_product(x, y, r) < target:
1310        x *= 2
1311    while r < size_hints[2] and conditional_product(x, y, r) < target:
1312        r *= 2
1313    while y < size_hints[1] and conditional_product(x, y, r) < target:
1314        y *= 2
1315
1316    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
1317    num_warps = _num_warps(conditional_product(x, y, r) // 256, min_num_warps=1)
1318    check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
1319    assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['R'] to {r}"
1320    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
1321
1322
1323def pointwise(
1324    size_hints,
1325    triton_meta,
1326    tile_hint=None,
1327    filename=None,
1328    min_elem_per_thread=0,
1329    inductor_meta=None,
1330):
1331    """
1332    Construct @triton.heuristics() based on size_hints.
1333    """
1334    inductor_meta = {} if inductor_meta is None else inductor_meta
1335    assert not inductor_meta.get("no_x_dim")
1336
1337    numel = functools.reduce(operator.mul, size_hints)
1338    bs = max(256, min(numel // 128, 1024))
1339
1340    hinted_configs = autotune_hints_to_configs(
1341        inductor_meta.get("autotune_hints", set()), size_hints, bs
1342    )
1343
1344    triton_config_with_settings = functools.partial(
1345        triton_config, min_elem_per_thread=min_elem_per_thread
1346    )
1347
1348    if len(size_hints) == 1:
1349        if disable_pointwise_autotuning(inductor_meta) and not (
1350            inductor_meta.get("max_autotune")
1351            or inductor_meta.get("max_autotune_pointwise")
1352        ):
1353            return cached_autotune(
1354                size_hints,
1355                [triton_config_with_settings(size_hints, bs)],
1356                triton_meta=triton_meta,
1357                inductor_meta=inductor_meta,
1358                heuristic_type=HeuristicType.POINTWISE,
1359                filename=filename,
1360            )
1361        else:
1362            return cached_autotune(
1363                size_hints,
1364                [
1365                    triton_config_with_settings(
1366                        size_hints, bs, num_elements_per_warp=256
1367                    ),
1368                    triton_config_with_settings(
1369                        size_hints, bs // 2, num_elements_per_warp=64
1370                    ),
1371                    *hinted_configs,
1372                ],
1373                triton_meta=triton_meta,
1374                inductor_meta=inductor_meta,
1375                heuristic_type=HeuristicType.POINTWISE,
1376                filename=filename,
1377            )
1378    if len(size_hints) == 2:
1379        if (
1380            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
1381        ) and not (
1382            inductor_meta.get("max_autotune")
1383            or inductor_meta.get("max_autotune_pointwise")
1384        ):
1385            return cached_autotune(
1386                size_hints,
1387                [triton_config_with_settings(size_hints, 32, 32)],
1388                triton_meta=triton_meta,
1389                inductor_meta=inductor_meta,
1390                heuristic_type=HeuristicType.POINTWISE,
1391                filename=filename,
1392            )
1393        return cached_autotune(
1394            size_hints,
1395            [
1396                triton_config_with_settings(size_hints, 32, 32),
1397                triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
1398                triton_config_with_settings(size_hints, 256, 16),
1399                triton_config_with_settings(size_hints, 16, 256),
1400                triton_config_with_settings(size_hints, bs, 1),
1401                triton_config_with_settings(size_hints, 1, bs),
1402                *hinted_configs,
1403            ],
1404            triton_meta=triton_meta,
1405            inductor_meta=inductor_meta,
1406            filename=filename,
1407            heuristic_type=HeuristicType.POINTWISE,
1408        )
1409    if len(size_hints) == 3:
1410        if disable_pointwise_autotuning(inductor_meta):
1411            return cached_autotune(
1412                size_hints,
1413                [triton_config_with_settings(size_hints, 16, 16, 16)],
1414                triton_meta=triton_meta,
1415                inductor_meta=inductor_meta,
1416                heuristic_type=HeuristicType.POINTWISE,
1417                filename=filename,
1418            )
1419        return cached_autotune(
1420            size_hints,
1421            [
1422                triton_config_with_settings(size_hints, 16, 16, 16),
1423                triton_config_with_settings(size_hints, 64, 8, 8),
1424                triton_config_with_settings(size_hints, 8, 64, 8),
1425                triton_config_with_settings(size_hints, 8, 8, 64),
1426                triton_config_with_settings(size_hints, bs, 1, 1),
1427                triton_config_with_settings(size_hints, 1, bs, 1),
1428                triton_config_with_settings(size_hints, 1, 1, bs),
1429                *hinted_configs,
1430            ],
1431            triton_meta=triton_meta,
1432            inductor_meta=inductor_meta,
1433            filename=filename,
1434            heuristic_type=HeuristicType.POINTWISE,
1435        )
1436    raise NotImplementedError(f"size_hints: {size_hints}")
1437
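# Note on the candidate space above: with two size hints, kernels with autotuning
# disabled or TileHint.SQUARE (and without max-autotune) fall back to a single 32x32
# tile, while the general path benchmarks square (32x32, 64x64), skewed (256x16,
# 16x256) and one-dimensional (bs x 1, 1 x bs) tiles plus any AutotuneHint-derived
# configs; the three-hint case mirrors this with an extra axis.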
1438
1439def _reduction_configs(
1440    *, size_hints: List[int], inductor_meta: Dict[str, Any]
1441) -> List[Config]:
1442    reduction_hint = inductor_meta.get("reduction_hint", None)
1443    assert len(size_hints) == 2
1444    rnumel = size_hints[-1]
1445
1446    register_intensive = False
1447    MAX_RBLOCK = 2048
1448    if (
1449        size_hints[0] >= 1024
1450        and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
1451        >= 10
1452    ):
1453        # A heuristic to reduce RBLOCK if a kernel potentially needs many registers.
1454        # Consider loads and reductions, since loads need to move data into registers
1455        # and a reduction needs an accumulator.
1456        #
1457        # The magic numbers are a bit arbitrary.
1458        #
1459        # We cannot rely on dynamically scaling down RBLOCK later, since sometimes
1460        # triton makes it use fewer registers with worse perf. See:
1461        # https://github.com/pytorch/pytorch/issues/126463
1462        #
1463        # The heuristic is a very simple one, since registers can be reused. But
1464        # hopefully it is a good enough indicator.
1465        MAX_RBLOCK = 1024
1466        register_intensive = True
1467
1468    contiguous_config = triton_config_reduction(
1469        size_hints,
1470        1,
1471        (rnumel if 256 <= rnumel < MAX_RBLOCK else MAX_RBLOCK),
1472        register_intensive=register_intensive,
1473    )
1474    outer_config = triton_config_reduction(
1475        size_hints, 64, 8, register_intensive=register_intensive
1476    )
1477    tiny_config = triton_config_reduction(
1478        size_hints,
1479        2 * (256 // rnumel) if rnumel <= 256 else 1,
1480        min(rnumel, MAX_RBLOCK),
1481        register_intensive=register_intensive,
1482    )
1483    if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
1484        pass  # max-autotune is on: don't prune to a single hint-based config, benchmark the full list below
1485    elif reduction_hint == ReductionHint.INNER:
1486        return [contiguous_config]
1487    elif reduction_hint == ReductionHint.OUTER:
1488        return [outer_config]
1489    elif reduction_hint == ReductionHint.OUTER_TINY:
1490        return [tiny_config]
1491    if disable_pointwise_autotuning(inductor_meta):
1492        return [triton_config_reduction(size_hints, 32, 128)]
1493    return [
1494        contiguous_config,
1495        outer_config,
1496        tiny_config,
1497        triton_config_reduction(size_hints, 64, 64),
1498        triton_config_reduction(size_hints, 8, 512),
1499        # halve RBLOCK relative to outer_config and use num_warps=8
1500        # TODO: this may only be beneficial when each iteration of the reduction
1501        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
1502        triton_config_reduction(size_hints, 64, 4, num_warps=8),
1503    ]
1504
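# Illustrative example (assumed size hints; the values may be clamped further inside
# triton_config_reduction): for size_hints=[8192, 2048] with fewer than 10 loads plus
# reductions, MAX_RBLOCK stays at 2048 and the named configs above request roughly
# (XBLOCK, RBLOCK) = (1, 2048) for contiguous_config, (64, 8) for outer_config and
# (1, 2048) for tiny_config; a matching ReductionHint selects one of them, otherwise
# (or under max-autotune) the broader list is benchmarked.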
1505
1506def reduction(
1507    size_hints,
1508    reduction_hint=False,
1509    triton_meta=None,
1510    filename=None,
1511    inductor_meta=None,
1512):
1513    """args to @triton.heuristics()"""
1514    inductor_meta = {} if inductor_meta is None else inductor_meta
1515    inductor_meta["reduction_hint"] = reduction_hint
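    # When the kernel is generated without an x dimension (no_x_dim), collapse the x
    # size hint to 1 so the config heuristics effectively tune only the reduction
    # blocking.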
1516    if inductor_meta.get("no_x_dim"):
1517        size_hints = [1, *size_hints[1:]]
1518
1519    assert triton_meta is not None
1520    rnumel = size_hints[-1]
1521    if len(size_hints) != 2:
1522        raise NotImplementedError(f"size_hints: {size_hints}")
1523
1524    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
1525    return cached_autotune(
1526        size_hints,
1527        configs=configs,
1528        triton_meta=triton_meta,
1529        inductor_meta=inductor_meta,
1530        heuristic_type=HeuristicType.REDUCTION,
1531        filename=filename,
1532    )
1533
1534
1535def persistent_reduction(
1536    size_hints,
1537    reduction_hint=False,
1538    triton_meta=None,
1539    filename=None,
1540    inductor_meta=None,
1541):
1542    inductor_meta = {} if inductor_meta is None else inductor_meta
1543    inductor_meta["reduction_hint"] = reduction_hint
1544    if inductor_meta.get("no_x_dim"):
1545        size_hints = [1, *size_hints[1:]]
1546
1547    xnumel, rnumel = size_hints
1548
1549    configs = [
1550        triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
1551        for xblock in (1, 8, 32, 128)
1552        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
1553    ]
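    # For example (assumed sizes): xnumel=1024, rnumel=512 keeps only XBLOCK in
    # {1, 8}, since 32*512 and 128*512 would exceed the 4096-element cap above.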
1554
1555    # TODO(jansel): we should be able to improve these heuristics
1556    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
1557        configs = configs[:1]
1558    elif reduction_hint == ReductionHint.OUTER:
1559        configs = configs[-1:]
1560    elif reduction_hint == ReductionHint.OUTER_TINY:
1561        configs = [
1562            triton_config_reduction(
1563                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
1564            )
1565        ]
1566    for c in configs:
1567        # RBLOCK is not needed for a persistent reduction: the whole reduction dimension is handled in one block
1568        c.kwargs.pop("RBLOCK")
1569
1570    if disable_pointwise_autotuning(inductor_meta):
1571        configs = configs[:1]
1572
1573    return cached_autotune(
1574        size_hints,
1575        configs,
1576        triton_meta=triton_meta,
1577        inductor_meta=inductor_meta,
1578        filename=filename,
1579        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
1580    )
1581
1582
1583def split_scan(
1584    size_hints,
1585    reduction_hint=False,
1586    triton_meta=None,
1587    filename=None,
1588    inductor_meta=None,
1589):
1590    """Heuristic for TritonSplitScanKernel"""
1591    inductor_meta = {} if inductor_meta is None else inductor_meta
1592    inductor_meta["reduction_hint"] = reduction_hint
1593    if inductor_meta.get("no_x_dim"):
1594        size_hints = [1, *size_hints[1:]]
1595
1596    assert triton_meta is not None
1597    if len(size_hints) != 2:
1598        raise NotImplementedError(f"size_hints: {size_hints}")
1599
1600    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
1601
1602    # Fix up the configs to enforce the minimum RBLOCK size
1603    min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
1604    for cfg in configs:
1605        if cfg.kwargs["RBLOCK"] < min_rblock:
1606            cfg.kwargs["RBLOCK"] = min_rblock
1607
1608    return cached_autotune(
1609        size_hints,
1610        configs=configs,
1611        triton_meta=triton_meta,
1612        inductor_meta=inductor_meta,
1613        heuristic_type=HeuristicType.SPLIT_SCAN,
1614        filename=filename,
1615    )
1616
1617
1618def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
1619    """
1620    Compile a triton template
1621    """
1622    return cached_autotune(
1623        None,
1624        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
1625        triton_meta=triton_meta,
1626        inductor_meta=inductor_meta,
1627        heuristic_type=HeuristicType.TEMPLATE,
1628        filename=filename,
1629    )
1630
1631
1632def user_autotune(
1633    configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
1634):
1635    """
1636    Compile a user-defined triton kernel
1637    """
1638    defaults = inspect.signature(triton.Config).parameters
1639    default_num_stages = defaults["num_stages"].default
1640    default_num_warps = defaults["num_warps"].default
1641
1642    if len(configs) == 0:
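    # Each user-supplied config is expected to be a dict shaped roughly like
    # {"kwargs": {...}, "num_stages": ..., "num_warps": ...}; missing fields fall back
    # to triton.Config's defaults extracted above.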
1643        configs = [
1644            triton.Config(
1645                {}, num_stages=default_num_stages, num_warps=default_num_warps
1646            )
1647        ]
1648    else:
1649        configs = [
1650            triton.Config(
1651                c.get("kwargs", {}),
1652                num_stages=c.get("num_stages", default_num_stages),
1653                num_warps=c.get("num_warps", default_num_warps),
1654            )
1655            for c in configs
1656        ]
1657
1658    return cached_autotune(
1659        None,
1660        configs,
1661        triton_meta=triton_meta,
1662        heuristic_type=HeuristicType.USER_AUTOTUNE,
1663        filename=filename,
1664        inductor_meta=inductor_meta,
1665        custom_kernel=custom_kernel,
1666    )
1667
1668
1669def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
1670    """
1671    Compile a triton foreach kernel
1672    """
1673    return cached_autotune(
1674        None,
1675        [triton.Config({}, num_stages=1, num_warps=num_warps)],
1676        triton_meta=triton_meta,
1677        inductor_meta=inductor_meta,
1678        heuristic_type=HeuristicType.TEMPLATE,
1679        filename=filename,
1680    )
1681
1682
1683def grid(*numels):
1684    """Helper function to compute triton grids"""
1685    if len(numels) == 1:
1686        xnumel, ynumel, znumel = numels[0], None, None
1687    elif len(numels) == 2:
1688        xnumel, ynumel, znumel = numels[1], numels[0], None
1689    elif len(numels) == 3:
1690        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
1691    else:
1692        raise AssertionError(f"invalid size for numels {len(numels)}")
1693
1694    def get_grid_dim(numel, block):
1695        if numel is None:
1696            return 1
1697        if block is None:
1698            return numel
1699        return ceildiv(numel, block)
1700
1701    def grid_fn(meta):
1702        x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
1703        y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))
1704
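        # If there is no z dimension, fold any y grid that exceeds the hardware's
        # maximum y extent into the z dimension; otherwise y must already fit.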
1705        max_y_grid = get_max_y_grid()
1706        if znumel is None:
1707            div = ceildiv(y_grid, max_y_grid)
1708            y_grid = ceildiv(y_grid, div)
1709            z_grid = div
1710        else:
1711            z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
1712            torch._check(
1713                y_grid <= max_y_grid,
1714                lambda: f"Generated y grid ({y_grid}) exceeds 2^16 and is not supported when a z dimension is present. Please file an issue.",
1715            )
1716
1717        return (
1718            x_grid,
1719            y_grid,
1720            z_grid,
1721        )
1722
1723    setattr(grid_fn, "grid_fn_str", f"grid{numels}")  # noqa: B010
1724
1725    return grid_fn
1726
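# Rough usage sketch (hypothetical numels and meta): grid(64, 4096) builds a grid_fn
# such that grid_fn({"XBLOCK": 128, "YBLOCK": 1}) == (32, 64, 1), assuming
# get_max_y_grid() is at least 64.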
1727
1728def split_scan_grid(xnumel, rnumel):
1729    def grid_fn(meta):
1730        assert meta.get("XBLOCK", 1) == 1
1731        return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1)
1732
1733    grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
1734    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010
1735
1736    return grid_fn
1737
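# Rough usage sketch (hypothetical values): split_scan_grid(4, 8192) returns a grid_fn
# where grid_fn({"XBLOCK": 1, "RBLOCK": 2048}) == (4, 4, 1): the scan is split across
# the first grid dimension and xnumel occupies the second.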
1738
1739def grid_combo_kernels(
1740    *numels, num_kernels, min_blocks, is_sequential, default_meta=None
1741):
1742    """min_blocks is the minimum size of the grid's x dimension."""
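    # Two dispatch modes: the non-sequential (round-robin) path builds one aggregated
    # grid and scales its x dimension by num_kernels, while the sequential path sums
    # the sub-kernels' individual x grids so their blocks are laid out back to back.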
1743    if not is_sequential:
1744        # round-robin dispatch
1745        numels_agg = list(numels)
1746        for i in range(len(numels_agg)):
1747            if isinstance(numels_agg[i], (list, tuple)):
1748                numels_agg[i] = max(max(numels_agg[i]), 0)  # noqa: PLW3301
1749        kernel_grid_fn = grid(*numels_agg)
1750
1751        if isinstance(numels[-1], (list, tuple)):
1752            min_blocks_d = max(-min(numels[-1]), 0) * num_kernels
1753        else:
1754            min_blocks_d = None
1755        if min_blocks is None:
1756            assert min_blocks_d is not None
1757            min_blocks = min_blocks_d
1758        else:
1759            assert (
1760                min_blocks_d is None or min_blocks == min_blocks_d
1761            ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
1762    else:
1763        # sequential dispatch
1764        seq_numels = list(numels)
1765        # The x numels are not used here; 1024 is just a placeholder
1766        seq_numels[-1] = 1024
1767        for i in range(len(seq_numels) - 1):
1768            if isinstance(seq_numels[i], (list, tuple)):
1769                seq_numels[i] = max(seq_numels[i])
1770
1771        kernel_grid_fn = grid(*seq_numels)
1772
1773    def get_grid_dim(numel, block):
1774        if numel is None:
1775            return 1
1776        if block is None:
1777            return numel
1778        return ceildiv(numel, block)
1779
1780    def grid_fn(meta):
1781        assert min_blocks is not None, "min_blocks must be a number"
1782        cuda_grid = list(kernel_grid_fn(meta))
1783        cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
1784        return tuple(cuda_grid)
1785
1786    def seq_grid_fn(meta):
1787        cuda_grid = list(kernel_grid_fn(meta))
1788        # x <= 0 means this kernel's x grid is not tunable (x_no_dim is true)
1789        x_grid = sum(
1790            [
1791                -x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1))
1792                for x in numels[-1]
1793            ]
1794        )
1795        cuda_grid[0] = x_grid
1796        return tuple(cuda_grid)
1797
1798    def grid_fn_default_meta(meta):
1799        return grid_fn(default_meta)
1800
1801    def seq_grid_fn_default_meta(meta):
1802        return seq_grid_fn(default_meta)
1803
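    # When default_meta is given, the returned function ignores the meta it is called
    # with and always evaluates the grid using default_meta.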
1804    if default_meta is None:
1805        return grid_fn if not is_sequential else seq_grid_fn
1806    else:
1807        return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta
1808