xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/triton_split_scan.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._prims_common import prod
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv


class TritonSplitScanKernel(TritonKernel):
    """Generates a Triton kernel that supports ``ops.scan`` calls while also
    splitting the reduction dimension over multiple Triton programs.

    For this kernel, loop numels will always take the form ``(xdim, rdim)``
    and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
    between blocks occurs within a global memory workspace buffer, which
    must be zero-filled before launching the kernel.

    Note that code generation for ``ops.reduction`` is not supported.

    For details of the communication strategy, see
    https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
    """

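    # Illustrative sizing (numbers assumed for the example, not taken from
    # this file): with xdim = 4, rdim = 10_000 and RBLOCK = 2048, the grid
    # above is (CeilDiv(10_000, 2048), 4) == (5, 4), i.e. five programs
    # cooperate on each of the four independent scan rows through the
    # zero-filled workspace buffer.
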
    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[OrderedSet[str]] = None,
        reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT,
        min_elem_per_thread=0,
    ) -> None:
        super().__init__(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache=None,
            reduction_hint=reduction_hint,
            min_elem_per_thread=min_elem_per_thread,
        )
        self.no_x_dim = True

    def should_use_persistent_reduction(self) -> bool:
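        # Splitting the reduction dimension across multiple programs is the
        # whole point of this kernel, so a persistent (single-program)
        # reduction never applies here.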
        return False

    def initialize_range_tree(self, pid_cache):
        prefixes = "yxr"
        assert len(self.numels) <= len(
            prefixes
        ), "z dimension not supported for split scan"
        active_prefixes = prefixes[len(prefixes) - len(self.numels) :]

        grid_dims = "rxy"
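        # "rxy" maps the reduction tree to grid dimension 0 and x to grid
        # dimension 1, matching the (CeilDiv(rdim, RBLOCK), xdim) grid shape
        # described in the class docstring.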
        for numel, prefix in zip(self.numels, active_prefixes):
            is_reduction = prefix == "r"
            tensor_dim = 0 if is_reduction else None
            grid_dim = grid_dims.find(prefix)
            self.range_trees.append(
                IterationRangesRoot(
                    f"{prefix}index",
                    numel,
                    prefix,
                    grid_dim,
                    self,
                    pid_cache=pid_cache,
                    is_loop=False,
                    tensor_dim=tensor_dim,
                    grid_dim=grid_dim,
                    has_zdim=False,
                )
            )

    def reduction(self, dtype, src_dtype, reduction_type, value):
        raise NotImplementedError("NYI TritonSplitScanKernel reductions")

    def scan(self, dtypes, combine_fn, values):
        import triton.language as tl

        (dtype,) = dtypes
        (value,) = values

        compute_type = triton_compute_type(dtype)
        compute_type_triton = getattr(tl, compute_type[3:])

        element_nbits = compute_type_triton.primitive_bitwidth

        scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
        scratch_type_triton = getattr(tl, scratch_type[3:])
        scratch_elems_per_block = 3 if element_nbits == 64 else 1
        scratch_nbytes_per_block = scratch_elems_per_block * (
            scratch_type_triton.primitive_bitwidth // 8
        )
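        # Values of up to 32 bits are packed together with their status flag
        # into a single scratch element per block; 64-bit values cannot share
        # an element with the flag, so they use three separate uint64 slots
        # (see the _64 helper below).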

        cse_load = functools.partial(self.cse.generate, self.loads)
        cse_compute = functools.partial(self.cse.generate, self.compute)

        assert len(self.numels) == 2, "Unexpected tiling"
        min_rblock = config.triton.min_split_scan_rblock
        max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock)
        nbytes = scratch_nbytes_per_block * max_blocks
        scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
        if offset != 0:
            scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")
        runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})")
        scratch_base = cse_load(
            f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
            f"{scratch_elems_per_block} * {runtime_rblocks}"
        )
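        # The workspace is sized for the worst case (the smallest allowed
        # RBLOCK yields the most blocks); at runtime each x row is offset into
        # its own slice of runtime_rblocks * scratch_elems_per_block elements.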

        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        masks = sorted(masks)
        assert not self._load_mask, "ops.scan not supported inside ops.masked"

        value = cse_compute(f"{value}.to({compute_type})")
        value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})")

        combine_helper_fn = self._lift_helper(combine_fn, 1)
        dim = self.triton_tensor_ndim() - 1
        assert dim == 0, "scan dimension must be the only tensor dimension"

        block_sum = cse_compute(f"tl.reduce({value}, {dim}, {combine_helper_fn})")
        exclusive_prefix = self.cse.newvar()
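        # Decoupled look-back (see the paper linked in the class docstring):
        # each block publishes its tile aggregate (block_sum) to the workspace,
        # then reads earlier blocks' entries until it can compute the running
        # exclusive prefix for its own tile.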
        if element_nbits == 64:
            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
                    {scratch_base},
                    {block_sum},
                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
                    {combine_helper_fn},
                )
                """,
                strip=True,
            )
        else:
            assert element_nbits <= 32
            value_as_uint_dtype = f"tl.uint{element_nbits}"

            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
                    {scratch_base},
                    {block_sum},
                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
                    {combine_helper_fn},
                    DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
                    DTYPE_PACK={scratch_type},
                )
                """,
                strip=True,
            )

        # Combine each element's block-local inclusive scan with the exclusive
        # prefix contributed by all earlier blocks.
        block_scan = cse_compute(
            f"tl.associative_scan({value}, {dim}, {combine_helper_fn})"
        )
        combined_result = cse_compute(
            f"{combine_helper_fn}({exclusive_prefix}, {block_scan})"
        )
        return (
            # The first block along r (roffset == 0) has no predecessors, so
            # its local scan is already the final result.
            cse_compute(f"tl.where(roffset == 0, {block_scan}, {combined_result})"),
        )

    def _get_heuristic(self):
        return "split_scan"

    def _get_grid_fn(self):
        return "split_scan_grid"
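
# Usage sketch (illustrative only; nothing below appears in this module):
# Inductor may select this kernel via the "split_scan" heuristic when lowering
# a sufficiently large one-dimensional scan on CUDA, for example:
#
#     import torch
#
#     @torch.compile
#     def cumsum(x):
#         return torch.cumsum(x, 0)
#
#     cumsum(torch.randn(2**24, device="cuda"))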