import itertools
import logging
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from sympy import Integer, Symbol

from torch.utils._ordered_set import OrderedSet

from .. import config, metrics
from ..runtime.hints import DeviceProperties, ReductionHint
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..scheduler import BaseSchedulerNode
from ..utils import Placeholder
from ..virtualized import V
from .common import (
    DeferredLine,
    IndentedBuffer,
    Kernel,
    PythonPrinter,
    SizeArg,
    WorkspaceArg,
)
from .simd import SIMDScheduling
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta


log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint
LARGE_NUMELS = 512e5
BLOCK_UTILIZATION = 0.8


def _default_custom_combo_kernel_horizontal_partition(
    nodes: List[BaseSchedulerNode],
    triton_scheduling: SIMDScheduling,
    kernel_map: Dict[BaseSchedulerNode, TritonKernel],
    node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
) -> List[List[BaseSchedulerNode]]:
    """Horizontally partition the given list of nodes into a list of lists of nodes, where each
    sublist represents a partition. Nodes in different partitions are implemented in different
    combo kernels. Nodes in the same partition are likely to be implemented in the same combo
    kernel, but subject to subsequent restrictions like CUDA limits on the number of args.

    Input arguments:
        nodes: a list of fused scheduler nodes to partition.
        triton_scheduling: TritonScheduling instance.
        kernel_map: a map from node to its kernel.
        node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
    Output:
        a list of lists of nodes, with each sublist representing a partition.

    The default algorithm partitions nodes based on the following rules:
        1) nodes with the same number of block dimensions are grouped together.
        2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes.
        3) large reduction nodes (rnumel > 2048) are separated from other nodes.
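
    For illustration (hypothetical nodes): given a 2D pointwise node A, a 2D
    long-reduction node B (rnumel > 2048), and a large 2D pointwise node C
    (numel > LARGE_NUMELS), the default partition is [[C], [A], [B]]: C is
    isolated by rule 2, B by rule 3, and A remains in the pointwise group.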
72    """
73
74    assert len(nodes) >= 1
75
76    # first partition nodes based on number of block dimensions
77    tilings = [node_info_map[n][1] for n in nodes]
78
79    max_dims = max(len(t) for t in tilings)
80    nodes_per_ndim = []
81    for i in range(2, max_dims + 1):
82        group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
83        reduction = [
84            n
85            for n in group_per_dim
86            if kernel_map[n].inside_reduction
87            and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim)
88        ]
89        not_reduction = [n for n in group_per_dim if n not in reduction]
90        # rnumel > 2048 usually has long execution time
91        # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
92        long_reduction = [
93            n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048
94        ]
95        short_reduction = [n for n in reduction if n not in long_reduction]
96        if long_reduction:
97            log.warning(
98                "ComboKernels: %d long reduction nodes are separated",
99                len(long_reduction),
100            )
101        large_pointwise = [
102            n
103            for n in not_reduction
104            if not kernel_map[n].inside_reduction
105            and len(kernel_map[n].numels) == 2
106            and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS
107        ]
108        if large_pointwise:
109            # TODO benchmark the performance when large pointwise nodes combining with others
110            log.warning(
111                "ComboKernels: %d large pointwise nodes are separated",
112                len(large_pointwise),
113            )
114            not_reduction = [n for n in not_reduction if n not in large_pointwise]
115            for node in large_pointwise:
116                nodes_per_ndim.append([node])
117
118        for g in (not_reduction, short_reduction, long_reduction):
119            if g:
120                nodes_per_ndim.append(g)
121
122    assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
123    return nodes_per_ndim


_custom_combo_kernel_horizontal_partition_algorithm: Callable[
    [
        List[BaseSchedulerNode],
        SIMDScheduling,
        Dict[BaseSchedulerNode, TritonKernel],
        Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
    ],
    List[List[BaseSchedulerNode]],
] = _default_custom_combo_kernel_horizontal_partition


def set_custom_combo_kernel_horizontal_partition(
    algorithm: Callable[
        [
            List[BaseSchedulerNode],
            SIMDScheduling,
            Dict[BaseSchedulerNode, TritonKernel],
            Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        ],
        List[List[BaseSchedulerNode]],
    ]
) -> None:
    """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
    are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
    in the same combo kernel, but subject to subsequent restrictions like CUDA limits on the number of args.

    The algorithm should take a list of nodes and return a list of lists of nodes.

    The default algorithm partitions nodes based on the number of block dimensions.
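
    Example (a minimal sketch; ``group_by_ndim`` is a hypothetical user
    function with the signature above)::

        def group_by_ndim(nodes, triton_scheduling, kernel_map, node_info_map):
            # group nodes purely by the length of their tiling, ignoring size
            groups = defaultdict(list)
            for n in nodes:
                groups[len(node_info_map[n][1])].append(n)
            return list(groups.values())

        set_custom_combo_kernel_horizontal_partition(group_by_ndim)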
155    """
156    global _custom_combo_kernel_horizontal_partition_algorithm
157    _custom_combo_kernel_horizontal_partition_algorithm = algorithm


@dataclass
class PartitionState:
    partitions: List[List[BaseSchedulerNode]]
    cur_partition: List[BaseSchedulerNode]
    cur_count: int

    def finalize(self) -> None:
        if self.cur_partition:
            self.partitions.append(self.cur_partition)


class ComboKernel(Kernel):
    MAX_NUM_ARGS = 250  # number where I would no longer get triton errors

    @staticmethod
    def _update_partition(
        partition_state: PartitionState,
        node_rw_count: int,
        node_info: BaseSchedulerNode,
    ) -> None:
        if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS:
            partition_state.partitions.append(partition_state.cur_partition)
            partition_state.cur_partition = [node_info]
            partition_state.cur_count = node_rw_count
        else:
            partition_state.cur_count += node_rw_count
            partition_state.cur_partition.append(node_info)

    @staticmethod
    def _base_horizontal_partition(
        subkernel_nodes: List[BaseSchedulerNode],
        triton_scheduling: SIMDScheduling,
        node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        custom_algorithm: bool,
    ) -> List[List[BaseSchedulerNode]]:
        """Partitions the given subkernel nodes into a list of lists of nodes, where each
        sublist is guaranteed not to exceed CUDA limits on the number of args
        (reads/writes) and to share the same 2D or 1D blocking strategy."""
        # TODO support combination of kernels with different block dimensions
        assert len(subkernel_nodes) >= 1
        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or (
            config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm
        )

        ndim_to_partition_state: Dict[int, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )
        yelem_to_partition_state: Dict[int, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )

        for node in subkernel_nodes:
            node_schedule, tiled_groups, numel, rnumel = node_info_map[node]
            node_info = node

            read_writes = node.read_writes
            read_write_count = len(read_writes.reads) + len(read_writes.writes)

            ndim = len(tiled_groups)
            assert ndim >= 2, f"ComboKernel does not support tiling {tiled_groups}"
            if not mixed_sizes and ndim == 3:
                y_elem = tiled_groups[0]
                partition_state = yelem_to_partition_state[y_elem]
                ComboKernel._update_partition(
                    partition_state, read_write_count, node_info
                )
            else:
                assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}"
                partition_state = ndim_to_partition_state[ndim]
                ComboKernel._update_partition(
                    partition_state, read_write_count, node_info
                )

        all_partitions = []
        for partition_state in ndim_to_partition_state.values():
            partition_state.finalize()
            all_partitions.extend(partition_state.partitions)
        for partition_state in yelem_to_partition_state.values():
            partition_state.finalize()
            all_partitions.extend(partition_state.partitions)

        return all_partitions

    @staticmethod
    def horizontal_partition(
        nodes: List[BaseSchedulerNode],
        triton_scheduling: SIMDScheduling,
        kernel_map: Dict[BaseSchedulerNode, TritonKernel],
        node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        custom_algorithm: bool = False,
    ) -> List[List[BaseSchedulerNode]]:
        """Horizontally partitions nodes into sublists, where each sublist forms a ComboKernel.
        It works in the following way:
            1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True
            2) then, call _base_horizontal_partition() to split the nodes into sublists, each of which
               is guaranteed not to exceed CUDA limits on the number of args (reads/writes) and to
               share the same 2D or 1D blocking strategy.
        """
        if custom_algorithm:
            raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
                nodes, triton_scheduling, kernel_map, node_info_map
            )
        else:
            raw_partitions = [nodes]

        all_partitions = []
        for raw_partition in raw_partitions:
            all_partitions.extend(
                ComboKernel._base_horizontal_partition(
                    raw_partition, triton_scheduling, node_info_map, custom_algorithm
                )
            )
        return all_partitions

    class SequentialDispatch:
        """
        The dispatcher which dispatches the subkernels in a sequential manner:
        the blocks are first dispatched to the 1st subkernel (until it is filled),
        then to the 2nd subkernel, and so on.
        The class defines the methods specific to the dispatch algorithm.
        Methods:
            codegen_pid_range(...): codegen the pid range for each subkernel.
            grid(...): codegen the grid size for launching the combo kernel.
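
        For illustration, the emitted dispatch prologue looks roughly like this
        (sketch for two subkernels; names follow codegen_pid_range and
        _calculate_xblocks):

            num_xblocks_0 = tl.cdiv(xnumel_0, XBLOCK)
            num_xblocks_1 = num_xblocks_0 + tl.cdiv(xnumel_1, XBLOCK)
            if pid < num_xblocks_0:
                pid_offset = pid
            elif pid < num_xblocks_1:
                pid_offset = pid - num_xblocks_0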
        """

        @classmethod
        def codegen_pid_range(
            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
        ) -> None:
            if num == 0:
                cls._calculate_xblocks(kernel, code)
                code.splice(f"if pid < num_xblocks_{num}:")
                with code.indent():
                    code.splice("pid_offset = pid")
            else:
                code.splice(f"elif pid < num_xblocks_{num}:")
                with code.indent():
                    code.splice(f"pid_offset = pid - num_xblocks_{num-1}")

        @classmethod
        def _calculate_xblocks(
            cls, kernel: "ComboKernel", code: IndentedBuffer
        ) -> None:
            x_numels_list = kernel.x_numels_list
            for i in range(len(x_numels_list)):
                xnumels, no_x_dim = (
                    (x_numels_list[i], False)
                    if isinstance(x_numels_list[i], str)
                    and cast(str, x_numels_list[i])[0] != "-"
                    or (
                        isinstance(x_numels_list[i], int)
                        and cast(int, x_numels_list[i]) > 0
                    )
                    else (kernel.min_x_blocks_list[i], True)
                )
                xblock_str = (
                    f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}"
                )
                if i == 0:
                    code.splice(f"num_xblocks_{i} = {xblock_str}")
                else:
                    code.splice(f"num_xblocks_{i} = num_xblocks_{i-1} + {xblock_str}")

        @classmethod
        def grid(
            cls,
            sub_kernel_numels: List[List[int]],
            x_blocks_list: List[Union[str, int]],
            dynamic_shape: bool,
        ) -> Tuple[Any, ...]:
            xnumel = list(x_blocks_list)
            ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
            znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

            if dynamic_shape:
                ynumel = None if None in ynumel else ynumel
                znumel = None if None in znumel else znumel
            else:
                # TODO: improve 1d/2d mixed cases
                ynumel = (
                    None
                    if any(e is None for e in cast(List[Any], ynumel))
                    else max(cast(Iterable[int], ynumel))
                )
                znumel = (
                    None
                    if any(e is None for e in cast(List[Any], znumel))
                    else max(cast(Iterable[int], znumel))
                )

            numels = (
                (xnumel,)
                if not ynumel
                else (ynumel, xnumel)
                if not znumel
                else (znumel, ynumel, xnumel)
            )
            return numels

    class RoundRobinDispatch:
        """
        The dispatcher which dispatches the subkernels in a round robin manner:
        the blocks are dispatched to the subkernels in an interleaved fashion
        so that they execute in parallel.
        The class defines the methods specific to the dispatch algorithm.
        Methods:
            codegen_pid_range(...): codegen the pid range for each subkernel.
            grid(...): codegen the grid size for launching the combo kernel.
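
        For illustration, with three subkernels the emitted dispatch looks
        roughly like this (sketch; names follow codegen_pid_range):

            if pid % 3 == 0:
                pid_offset = pid // 3
            elif pid % 3 == 1:
                pid_offset = pid // 3
            elif pid % 3 == 2:
                pid_offset = pid // 3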
        """

        @classmethod
        def codegen_pid_range(
            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
        ) -> None:
            num_kernels = len(kernel.sub_kernels)
            if num == 0:
                cond = "if"
            else:
                cond = "elif"
            code.splice(f"{cond} pid % {num_kernels} == {num}:")
            with code.indent():
                code.splice(f"pid_offset = pid // {num_kernels}")

        @classmethod
        def grid(
            cls,
            sub_kernel_numels: List[List[int]],
            x_blocks_list: List[Union[str, int]],
            dynamic_shape: bool,
        ) -> Tuple[Any, ...]:
            xnumel = x_blocks_list
            # set no_x_dim xnumels to 0
            xnumel_x_dim = [max(e, 0) for e in xnumel]
            ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
            znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

            # TODO: support 1d/2d mixed cases
            xnumel = (
                None
                if any(e is None for e in xnumel)
                else xnumel
                if dynamic_shape
                else max(xnumel_x_dim)  # type: ignore[type-var, arg-type]
            )
            ynumel = (
                None
                if any(e is None for e in ynumel)
                else ynumel
                if dynamic_shape
                else max(ynumel)  # type: ignore[type-var, arg-type]
            )
            znumel = (
                None
                if any(e is None for e in znumel)
                else znumel
                if dynamic_shape
                else max(znumel)  # type: ignore[type-var, arg-type]
            )

            numels = (
                (xnumel,)
                if not ynumel
                else (ynumel, xnumel)
                if not znumel
                else (znumel, ynumel, xnumel)
            )
            return numels

    def __init__(
        self, enable_autotune: bool = False, mixed_sizes: bool = False
    ) -> None:
        super().__init__()
        self.sub_kernels: List[TritonKernel] = []
        self.iter_vars_count = itertools.count()
        self.grids: List[List[int]] = []
        self.min_x_blocks_list: List[Union[int, str]] = []
        self.x_numels_list: List[Union[int, str]] = []
        self.enable_autotune = enable_autotune
        self.mixed_sizes = mixed_sizes
        self.dispatch_class: Optional[
            Union[
                Type[ComboKernel.SequentialDispatch],
                Type[ComboKernel.RoundRobinDispatch],
            ]
        ] = None
        self.block_args: List[str] = []
        # the following are used when autotuning is disabled
        self.block_size_1d = 1024  # Try tuning this value
        self.block_size_2d = 32
        self.num_warps = 8
        self.block_size_reduce = 256
        self.dynamic_shape_args: List[str] = []

    def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
        sub_kernel = triton_kernel
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(sub_kernel)
        return sub_kernel

    @staticmethod
    def create_triton_kernel(
        *groups: Any,
        index_dtype: str,
        mutations: OrderedSet[str],
        reduction_hint: ReductionHint,
        optimize_mask: bool,
    ) -> TritonKernel:
        """
        Only allow optimize_mask=True when 1) sequential dispatch is used,
        and 2) numels except for the x dimension are the same for each sub kernel.
        """
        return TritonKernel(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache={"tl.program_id(0)": "pid_offset"},
            reduction_hint=reduction_hint,
            optimize_mask=optimize_mask,
        )

    def codegen_static_numels_sub_kernel(
        self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int
    ) -> List[str]:
        """
        We get a small speedup from hard coding numels if they are static.

        This code stomps on the passed-in values by writing a constant to the top of the kernel.

        In a kernel like:
        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):

        We would add
        xnumel = 4096
        rnumel = 768

        after the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes
        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
        knows that it's a static numel, as that you just plop a constant into the kernel.
        """
        grid = []
        uniquify_block_sizes = []
        for tree in sub_kernel.range_trees:
            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
            if isinstance(simplified_tree_numel, (Integer, int)):
                code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
            else:
                assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
                uniquify_block_sizes.append(f"{tree.prefix}numel")

            if tree.prefix != "r":
                if isinstance(simplified_tree_numel, (Integer, int)):
                    grid.append(int(simplified_tree_numel))
                else:
                    grid.append(f"{tree.prefix}numel_{num}")

            if tree.prefix == "r" and sub_kernel.persistent_reduction:
                if isinstance(simplified_tree_numel, (Integer, int)):
                    val = int(simplified_tree_numel)
                else:
                    raise RuntimeError(
                        "Dynamic shape on reduction dimension is not supported"
                    )
                val = next_power_of_2(val)
                code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
                uniquify_block_sizes.append("RBLOCK")

            if tree.prefix == "x" and sub_kernel.no_x_dim:
                code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
                uniquify_block_sizes.append("XBLOCK")
        self.grids.append(grid)
        return uniquify_block_sizes

    def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None:
        """
        Kernels with no_x_dim being true have no tunable XBLOCK. They have a fixed number of X blocks.
        Grid calculation needs to make sure that they are assigned enough blocks.
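
        Note: a no_x_dim subkernel records its fixed block count in
        min_x_blocks_list, and its entry in x_numels_list is negated (or
        "-"-prefixed for a dynamic shape) so that later passes can tell the
        two cases apart.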
        """
        min_x_blocks: Union[int, str] = 0
        x_numels: Union[int, str] = 0
        for tree in sub_kernel.range_trees:
            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
            if tree.prefix == "x":
                if isinstance(simplified_tree_numel, (Integer, int)):
                    x_numels = int(simplified_tree_numel)
                else:
                    x_numels = f"{tree.prefix}numel_{num}"
                if sub_kernel.no_x_dim:
                    min_x_blocks = x_numels
                    x_numels = (
                        -min_x_blocks
                        if isinstance(x_numels, int)
                        else "-" + cast(str, x_numels)
                    )
                else:
                    if isinstance(simplified_tree_numel, (Integer, int)):
                        x_numels = int(simplified_tree_numel)
                    else:
                        x_numels = f"{tree.prefix}numel_{num}"
        self.min_x_blocks_list.append(min_x_blocks)
        self.x_numels_list.append(x_numels)

    def select_heuristics(self, sub_kernel: TritonKernel) -> Tuple[str, List[int]]:
        size_hints = [
            next_power_of_2(V.graph.sizevars.size_hint(numel))
            for numel in sub_kernel.numels
        ]
        if sub_kernel.persistent_reduction:
            assert sub_kernel.inside_reduction
            heuristics = "persistent_reduction"
        elif sub_kernel.inside_reduction:
            heuristics = "reduction"
        else:
            size_hints.pop()
            heuristics = "pointwise"
        return heuristics, size_hints

    def select_combo_heuristics(
        self, heuristics_list: List[str], size_hints_list: List[List[int]]
    ) -> Tuple[str, List[int], TritonKernel]:
        if not self.enable_autotune:
            return "foreach", size_hints_list[0], self.sub_kernels[0]
        if "reduction" in heuristics_list:
            i, _ = max(
                enumerate(size_hints_list),
                key=lambda x: x[1][0] if heuristics_list[x[0]] == "reduction" else 0,
            )
            return heuristics_list[i], size_hints_list[i], self.sub_kernels[i]
        elif "pointwise" in heuristics_list:
            i, _ = max(
                enumerate(size_hints_list),
                key=lambda x: x[1][0] if heuristics_list[x[0]] == "pointwise" else 0,
            )
            # modify size_hint to avoid oom check fail (may be a false alarm)
            num_pointwise = len([e for e in heuristics_list if e == "pointwise"])
            num_reduction = len([e for e in heuristics_list if e == "reduction"])
            num_persistent_reduction = len(
                [e for e in heuristics_list if e == "persistent_reduction"]
            )
            assert (
                num_reduction == 0
            ), "combining pointwise and reduction is not supported yet."
            heuristics = (
                "pointwise_with_reduction"
                if num_persistent_reduction > 0
                else "pointwise"
            )
            if len(heuristics_list) - num_pointwise >= 4:
                size_hints = size_hints_list[i]
                size_hints[0] = min(128, size_hints[0])
            return heuristics, size_hints_list[i], self.sub_kernels[i]
        else:
            return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]

    def get_mutated_args_sub_kernels(self) -> List[str]:
        mutated_args = set()
        for sub_kernel in self.sub_kernels:
            for mutation in sub_kernel.mutations:
                if mutation in sub_kernel.args.input_buffers:
                    mutated_args.add(sub_kernel.args.input_buffers[mutation])
                if (
                    mutation in sub_kernel.args.inplace_buffers
                    and mutation not in V.graph.removed_buffers
                    and mutation not in sub_kernel.removed_buffers
                ):
                    mutated_args.add(
                        sub_kernel.args.inplace_buffers[mutation].inner_name
                    )
                if mutation in sub_kernel.args.output_buffers:
                    mutated_args.add(sub_kernel.args.output_buffers[mutation])
        return sorted(mutated_args)

    def select_dispatch_strategy(self) -> None:
        if self.dispatch_class is not None:
            return
        # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch.
        # Kernels without mixed sizes on the y dim could technically use round robin as well.
        if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list):
            # a str in x_numels_list means a dynamic shape
            self.dispatch_class = ComboKernel.SequentialDispatch
            return
        # A negative x_numels_list element means the kernel is not tunable,
        # i.e., no_x_dim = True
        x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list]
        total = max(x_numels_list) * len(x_numels_list)
        needed = sum(x_numels_list)
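        # Worked example (illustrative numbers): x_numels_list = [1024, 768, 896]
        # gives total = 1024 * 3 = 3072 and needed = 2688, so utilization is
        # 2688 / 3072 = 0.875 > BLOCK_UTILIZATION (0.8) and round robin is chosen.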
        if needed / total > BLOCK_UTILIZATION:
            # Introduced overhead (masked blocks) is less than 20%
            self.dispatch_class = ComboKernel.RoundRobinDispatch
        else:
            self.dispatch_class = ComboKernel.SequentialDispatch

    def jit_line(
        self,
        heuristics: str,
        size_hints: List[int],
        selected_kernel: TritonKernel,
        pointwise_with_reduce: bool = False,
        signature: Optional[List[Any]] = None,
    ) -> str:
        can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
        size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
        if signature is None:
            _, _, signature, _ = self.args.python_argdefs()
        for i, sub in enumerate(self.sub_kernels):
            self.min_x_blocks_sub_kernel(sub, i)
        self.select_dispatch_strategy()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=size_dtype),
            "device": DeviceProperties.create(
                V.graph.scheduler.get_current_device_or_throw()
            ),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        mutated_args = self.get_mutated_args_sub_kernels()
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "mutated_arg_names": mutated_args,
            **TritonKernel.inductor_meta_common(),
        }

        sub_kernel = selected_kernel
        if heuristics == "foreach":
            heuristics_line = f"""
                @triton_heuristics.foreach(
                    num_warps={self.num_warps},
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r},
                )
                @triton.jit
            """
        elif sub_kernel.inside_reduction:
            reduction_hint = sub_kernel.reduction_hint
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """
        else:
            if len(size_hints) == 2:
                tile_hint = "tile_hint=TileHint.SQUARE,"
            else:
                tile_hint = "tile_hint=TileHint.DEFAULT,"
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r}, {tile_hint}
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """

        return heuristics_line

    def codegen_blocks(self, code: IndentedBuffer) -> None:
        for block in self.block_args:
            assert block in [
                "XBLOCK",
                "YBLOCK",
                "RBLOCK",
            ], f"{block} is not supported without autotuning"
        if "YBLOCK" in self.block_args:
            code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
            code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
        else:
            code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
        if "RBLOCK" in self.block_args:
            code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")

    def add_blockd_to_args(self, argdefs: List[str]) -> List[str]:
        block_args = {}
        block_names = {}
        for num, sub_kernel in enumerate(self.sub_kernels):
            # TODO: we assume all sub_kernels have the same block size
            for tree in sub_kernel.range_trees:
                if tree.prefix == "r" and (
                    not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
                ):
                    continue
                if tree.prefix == "x" and sub_kernel.no_x_dim:
                    continue
                block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix
                block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix
        if self.enable_autotune:
            argdefs.extend(block_args)
        self.block_args = list(block_names.keys())
        return argdefs

    def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]:
        for num, sub_kernel in enumerate(self.sub_kernels):
            for tree in sub_kernel.active_range_trees():
                if not isinstance(tree.numel, (Integer, int)):
                    # only if it is a dynamic shape
                    sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel)
                    signature.append(sizearg)
                    argdefs.append(f"{tree.prefix}numel_{num}")
                    self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
        return argdefs

    def add_numel_to_call_args_and_grid(
        self, name: str, call_args: List[Any], arg_types: List[Any], grid: List[Any]
    ) -> None:
        for num, sub_kernel in enumerate(self.sub_kernels):
            for i, tree in enumerate(sub_kernel.range_trees):
                numel_name = f"{tree.prefix}numel_{num}"
                if numel_name not in self.dynamic_shape_args:
                    continue
                if isinstance(tree.numel, (Integer, Symbol)):
                    expr = tree.numel
                else:
                    expr = V.graph.wrapper_code.generate_numel_expr(
                        name, tree, suffix=str(num)
                    )
                if tree.prefix != "r":
                    assert isinstance(
                        grid[i][num], str
                    ), f"Grid {grid[i][num]} should be a dynamic shape."
                    numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
                    assert (
                        grid[i][num] == numel_sign + numel_name
                    ), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
                    grid[i][num] = -expr if numel_sign == "-" else expr

                if tree.prefix != "r" or sub_kernel.inside_reduction:
                    call_args.append(expr)
                    arg_types.append(type(expr))

    def add_numel_to_call_args_and_grid_benchmark(
        self, extra_args: List[Any], grid: Union[List[Any], Tuple[Any, ...]]
    ) -> None:
        for num, sub_kernel in enumerate(self.sub_kernels):
            for i, tree in enumerate(sub_kernel.range_trees):
                numel_name = f"{tree.prefix}numel_{num}"
                if numel_name not in self.dynamic_shape_args:
                    continue
                expr = V.graph.sizevars.size_hint(tree.numel)
                if tree.prefix != "r":
                    assert isinstance(
                        grid[i][num], str
                    ), f"Grid {grid[i][num]} should be a dynamic shape."
                    numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
                    assert (
                        grid[i][num] == numel_sign + numel_name
                    ), f"grid mismatch: {grid[i][num]} vs {numel_name}"
                    grid[i][num] = -expr if numel_sign == "-" else expr
                if tree.prefix != "r" or sub_kernel.inside_reduction:
                    extra_args.append(expr)

    def codegen_kernel(self, name: Optional[str] = None) -> str:
        # TODO: is it correct to use the first sub kernel's heuristics?
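        # The generated combo kernel has roughly this shape (illustrative sketch;
        # the dispatch predicate depends on the chosen dispatch_class):
        #
        #     @triton_heuristics.pointwise(...)
        #     @triton.jit
        #     def combo_kernel(..., XBLOCK: tl.constexpr):
        #         pid = tl.program_id(0)
        #         if pid < num_xblocks_0:        # or: if pid % num_kernels == 0:
        #             pid_offset = pid
        #             <sub-kernel 0 body>
        #         elif pid < num_xblocks_1:
        #             pid_offset = pid - num_xblocks_0
        #             <sub-kernel 1 body>
        #         else:
        #             pass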
        heuristics_list, size_hints_list = [], []
        for subkernel in self.sub_kernels:
            h, s = self.select_heuristics(subkernel)
            heuristics_list.append(h)
            size_hints_list.append(s)
        heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
            heuristics_list, size_hints_list
        )
        pointwise_with_reduction, heuristics = (
            (True, "pointwise")
            if heuristics == "pointwise_with_reduction"
            else (False, heuristics)
        )
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        if config.benchmark_combo_kernel:
            code.splice(self.imports_for_benchmark_kernel())

        argdefs, _, signature, _ = self.args.python_argdefs()
        argdefs = self.add_numel_to_args(argdefs, signature)
        argdefs = self.add_blockd_to_args(argdefs)
        code.splice(
            self.jit_line(
                heuristics,
                size_hints,
                selected_kernel,
                pointwise_with_reduce=pointwise_with_reduction,
                signature=signature,
            )
        )
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("pid = tl.program_id(0)")
            if not self.enable_autotune:
                self.codegen_blocks(code)

            for num, sub_kernel in enumerate(self.sub_kernels):
                assert self.dispatch_class is not None
                self.dispatch_class.codegen_pid_range(self, num, code)
                with code.indent():
                    uniquify = self.codegen_static_numels_sub_kernel(
                        code, sub_kernel, num
                    )
                    sub_kernel.codegen_body()
                    uniquified_body = self.uniquify_block_sizes(
                        sub_kernel.body, num, uniquify
                    )
                    code.splice(uniquified_body)

            code.splice("else:")
            with code.indent():
                code.splice("pass")

        if config.benchmark_combo_kernel:
            code.splice(self.codegen_kernel_benchmark(num_gb=0))

        return code.getvalue()

    def codegen_kernel_benchmark(
        self, num_gb: float, grid: Optional[List[Any]] = None
    ) -> IndentedBuffer:
        result = IndentedBuffer()
        argdefs, call_args, signature, _ = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            name_cnt = itertools.count()
            var_names = []
            for arg_name, arg_sig in zip(call_args, signature):
                var_name = f"arg_{next(name_cnt)}"
                buf = V.graph.try_get_buffer(arg_name)
                if buf:
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offsets will have the same benchmark harness.
                    # We can dedup kernel definitions in this case.
                    if "seed_offset" in arg_sig.name:
                        symval_hint = 0
                    result.writeline(f"{var_name} = {symval_hint}")
                elif isinstance(arg_sig, WorkspaceArg):
                    device = V.graph.scheduler.get_current_device_or_throw()
                    nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes)
                    result.writeline(
                        f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)"
                    )
                else:
                    raise KeyError(
                        f"Could not find the buffer or const tensor for {arg_name}"
                    )
                var_names.append(var_name)
            result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
        if grid is None:
            assert self.dispatch_class is not None
            dynamic_shape = self.dynamic_shape_args != []
            grid_tuple = self.dispatch_class.grid(
                self.grids, self.x_numels_list, dynamic_shape
            )
            extra_args_str = ""
            extra_args: List[Any] = []
            if dynamic_shape:
                self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple)
                # convert nested list to list of str
                grid_tuple = tuple(
                    "[" + ", ".join(pexpr(item) for item in e) + ",]"
                    for e in grid_tuple
                )
                extra_args_str = ", ".join(map(str, extra_args)) + ", "
                min_blocks = None
            else:
                min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels)
            grid_str = ", ".join(pexpr(item) for item in grid_tuple)
            grid_extra_kwargs = (
                f"num_kernels={len(self.sub_kernels)}, "
                f"min_blocks={min_blocks}, "
                f"is_sequential={self.dispatch_class is self.SequentialDispatch}"
            )
            grid_str = f"{grid_str}, {grid_extra_kwargs}"
            grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})"
        else:
            grid_arg = f"grid={grid}"
        index = V.graph.scheduler.get_current_device_or_throw().index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_raw_stream({index})")
                result.writeline(
                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                result.writeline(
                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
                )

        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline(
                "from torch._inductor.runtime.benchmarking import benchmarker"
            )
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)"
            )
            result.writeline(f"num_gb = {num_gb}")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")'
            )

        return result

    def imports_for_benchmark_kernel(self) -> str:
        return textwrap.dedent(
            """
            from torch._dynamo.testing import rand_strided
            {}
            import torch
            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
        """.format(
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
            )
        )

    def uniquify_block_sizes(
        self, code: IndentedBuffer, num_kernel: int, uniquify: List[str]
    ) -> IndentedBuffer:
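        """Renames each symbol in ``uniquify`` (e.g. "XBLOCK" -> "XBLOCK_0") within
        the given sub-kernel body, so that every sub-kernel gets its own constexpr
        block sizes and numels."""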
        if not uniquify:
            return code
        modified = IndentedBuffer(initial_indent=code._indent)
        for line in code._lines:
            if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]):
                modified_line = line
                for block in blocks:
                    modified_line = modified_line.replace(
                        block, f"{block}_{num_kernel}"
                    )
                modified.writeline(modified_line)
            elif isinstance(line, DeferredLine) and (
                blocks := [e for e in uniquify if e in line.line]
            ):
                modified_line = line.line
                for block in blocks:
                    modified_line = modified_line.replace(
                        block, f"{block}_{num_kernel}"
                    )
                new_line = DeferredLine(line.name, modified_line)
                modified.writeline(new_line)
            else:
                modified.writeline(line)
        return modified

    def call_kernel(self, code: IndentedBuffer, name: str) -> None:
        _, call_args, _, arg_types = self.args.python_argdefs()

        wrapper = V.graph.wrapper_code
        assert self.dispatch_class is not None
        dynamic_shape = self.dynamic_shape_args != []
        grid = list(
            self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape)
        )
        num_kernels = len(self.sub_kernels)
        min_blocks = (
            max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None
        )
        is_sequential = self.dispatch_class is self.SequentialDispatch
        if dynamic_shape:
            self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
            # convert nested list to list of str
            # grid = tuple("["+", ".join(pexpr(item) for item in e)+",]" for e in grid)
        if not self.enable_autotune and not dynamic_shape:
            launch_grid = self.grid_no_autotune(
                grid, num_kernels, cast(int, min_blocks), is_sequential
            )
            V.graph.wrapper_code.generate_kernel_call(
                name,
                call_args,
                grid=launch_grid,
                arg_types=arg_types,
                grid_fn="",
            )
            return
        # autotuning is enabled
        grid = wrapper.generate_default_grid(
            name,
            list(grid),
            grid_callable=grid_combo_kernels,
            num_kernels=num_kernels,
            min_blocks=min_blocks,
            is_sequential=is_sequential,
            default_meta=None if self.enable_autotune else self.get_default_meta(),
        )
        wrapper.generate_kernel_call(
            name,
            call_args,
            grid,
            V.graph.scheduler.get_current_device_or_throw().index,
            cuda=True,
            triton=True,
            arg_types=arg_types,
            grid_fn="grid_combo_kernels",
            grid_extra_kwargs=(
                f"num_kernels={num_kernels}, "
                f"min_blocks={min_blocks}, "
                f"is_sequential={is_sequential}, "
                f"default_meta={None if self.enable_autotune else self.get_default_meta()}"
            ),
        )

    def grid_no_autotune(
        self,
        grid: Union[Tuple[Any], List[Any]],
        num_kernels: int,
        min_blocks: int,
        is_sequential: bool,
    ) -> List[int]:
        meta = self.get_default_meta()
        grid_func = grid_combo_kernels(
            *grid,
            num_kernels=num_kernels,
            min_blocks=min_blocks,
            is_sequential=is_sequential,
        )
        return grid_func(meta)

    def get_default_meta(self) -> Dict[str, int]:
        if "YBLOCK" in self.block_args:
            meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d}
        else:
            meta = {"XBLOCK": self.block_size_1d}
        return meta