# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch._inductor.runtime.hints
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._prims_common import prod
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv


class TritonSplitScanKernel(TritonKernel):
    """Generates a triton kernel that supports ops.scan calls while also splitting
    the reduction dimension over multiple triton programs.

    For this kernel, loop numels will always take the form ``(xdim, rdim)``
    and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
    between blocks occurs within a global memory workspace buffer, which
    must be zero-filled before launching the kernel.

    Note that generation for ``ops.reduction`` is not supported.

    For details of the communication strategy, see
    https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    """

    def __init__(
        self,
        *groups,
        index_dtype: str,
        mutations: Optional[OrderedSet[str]] = None,
        reduction_hint=torch._inductor.runtime.hints.ReductionHint.DEFAULT,
        min_elem_per_thread=0,
    ) -> None:
        super().__init__(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache=None,
            reduction_hint=reduction_hint,
            min_elem_per_thread=min_elem_per_thread,
        )
        self.no_x_dim = True

    def should_use_persistent_reduction(self) -> bool:
        return False

    def initialize_range_tree(self, pid_cache):
        prefixes = "yxr"
        assert len(self.numels) <= len(
            prefixes
        ), "z dimension not supported for split scan"
        active_prefixes = prefixes[len(prefixes) - len(self.numels) :]

        grid_dims = "rxy"
        for numel, prefix in zip(self.numels, active_prefixes):
            is_reduction = prefix == "r"
            tensor_dim = 0 if is_reduction else None
            grid_dim = grid_dims.find(prefix)
            self.range_trees.append(
                IterationRangesRoot(
                    f"{prefix}index",
                    numel,
                    prefix,
                    grid_dim,
                    self,
                    pid_cache=pid_cache,
                    is_loop=False,
                    tensor_dim=tensor_dim,
                    grid_dim=grid_dim,
                    has_zdim=False,
                )
            )

    def reduction(self, dtype, src_dtype, reduction_type, value):
        raise NotImplementedError("NYI TritonSplitScanKernel reductions")

    def scan(self, dtypes, combine_fn, values):
        import triton.language as tl

        (dtype,) = dtypes
        (value,) = values

        compute_type = triton_compute_type(dtype)
        compute_type_triton = getattr(tl, compute_type[3:])

        element_nbits = compute_type_triton.primitive_bitwidth

        # Choose the scratch element type and count used for cross-block
        # communication; 64-bit values need three scratch elements per block,
        # narrower values fit in a single scratch word.
        scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
        scratch_type_triton = getattr(tl, scratch_type[3:])
        scratch_elems_per_block = 3 if element_nbits == 64 else 1
        scratch_nbytes_per_block = scratch_elems_per_block * (
            scratch_type_triton.primitive_bitwidth // 8
        )

        cse_load = functools.partial(self.cse.generate, self.loads)
        cse_compute = functools.partial(self.cse.generate, self.compute)

        assert len(self.numels) == 2, "Unexpected tiling"
        # Size the workspace for the worst case: the largest number of blocks
        # that can be launched, i.e. with the smallest permitted RBLOCK.
        min_rblock = config.triton.min_split_scan_rblock
        max_blocks = prod(self.numels[:-1]) * CeilDiv(self.numels[-1], min_rblock)
        nbytes = scratch_nbytes_per_block * max_blocks
        scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
        if offset != 0:
            scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")
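        # Carve the zero-filled workspace into one contiguous slice per x index:
        # each slice holds ``scratch_elems_per_block`` scratch elements for every
        # launched r-block, which is exactly what the
        # ``xoffset * scratch_elems_per_block * runtime_rblocks`` offset below
        # skips over.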
        runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})")
        scratch_base = cse_load(
            f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
            f"{scratch_elems_per_block} * {runtime_rblocks}"
        )

        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        masks = sorted(masks)
        assert not self._load_mask, "ops.scan not supported inside ops.masked"

        value = cse_compute(f"{value}.to({compute_type})")
        value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})")

        combine_helper_fn = self._lift_helper(combine_fn, 1)
        dim = self.triton_tensor_ndim() - 1
        assert dim == 0, "scan dimension must be the only tensor dimension"

        # Reduce this program's block to a single aggregate, then exchange
        # aggregates through the workspace via decoupled look-back to obtain the
        # exclusive prefix of all preceding blocks.
        block_sum = cse_compute(f"tl.reduce({value}, {dim}, {combine_helper_fn})")
        exclusive_prefix = self.cse.newvar()
        if element_nbits == 64:
            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
                    {scratch_base},
                    {block_sum},
                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
                    {combine_helper_fn},
                )
                """,
                strip=True,
            )

        else:
            assert element_nbits <= 32
            value_as_uint_dtype = f"tl.uint{element_nbits}"

            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
                    {scratch_base},
                    {block_sum},
                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
                    {combine_helper_fn},
                    DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
                    DTYPE_PACK={scratch_type},
                )
                """,
                strip=True,
            )
        # Combine the exclusive prefix with this block's inclusive scan to get
        # the final scan result.
        block_scan = cse_compute(
            f"tl.associative_scan({value}, {dim}, {combine_helper_fn})"
        )
        combined_result = cse_compute(
            f"{combine_helper_fn}({exclusive_prefix}, {block_scan})"
        )
        # The first block along the reduction dim (roffset == 0) has no preceding
        # prefix, so it returns its local scan unchanged.
        return (
            cse_compute(f"tl.where(roffset == 0, {block_scan}, {combined_result})"),
        )

    def _get_heuristic(self):
        return "split_scan"

    def _get_grid_fn(self):
        return "split_scan_grid"
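

# Worked example of the combine step in ``scan`` (illustrative only, assuming the
# combine function is addition, i.e. a cumulative sum): with RBLOCK = 4 and a row
# [1, 2, 3, 4, 5, 6, 7, 8], program 0 scans [1, 2, 3, 4] -> [1, 3, 6, 10] and, having
# roffset == 0, returns it unchanged.  Program 1 publishes block_sum = 26, receives
# exclusive_prefix = 10 from the decoupled look-back exchange, scans its block to
# [5, 11, 18, 26], and combines the two to give [15, 21, 28, 36].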