1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import collections 5import contextlib 6import dataclasses 7import functools 8import itertools 9import logging 10import math 11import operator 12from typing import ( 13 Any, 14 Callable, 15 Counter, 16 DefaultDict, 17 Dict, 18 Iterable, 19 List, 20 Optional, 21 Sequence, 22 Tuple, 23 Union, 24) 25 26import sympy 27 28import torch 29import torch._logging 30from torch.utils._ordered_set import OrderedSet 31from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing 32from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT 33 34from ..._dynamo.utils import counters 35from .. import config, ir, scheduler 36from ..codecache import code_hash 37from ..dependencies import Dep, MemoryDep, StarDep, WeakDep 38from ..ir import IRNode, TritonTemplateBuffer 39from ..optimize_indexing import indexing_dtype_strength_reduction 40from ..runtime.hints import ReductionHint 41from ..runtime.runtime_utils import green_text, yellow_text 42from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse 43from ..utils import ( 44 get_dtype_size, 45 IndentedBuffer, 46 Placeholder, 47 sympy_index_symbol, 48 sympy_product, 49 sympy_subs, 50 unique, 51) 52from ..virtualized import ops, OpsWrapper, V 53from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter 54from .multi_kernel import MultiKernel 55 56 57log = logging.getLogger(__name__) 58perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") 59schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") 60fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") 61 62 63pexpr = PythonPrinter().doprint 64 65 66@dataclasses.dataclass 67class IterationRanges: 68 """ 69 Each range tree represents multiple sets of iteration indexing 70 in a single tiled dimension in the output kernel. 71 72 If you have two loops ranges one (4, 3, 2) and another (4, 6), 73 then the range tree will be: 74 4 (i0) 75 3 (i1) 6 (i3) 76 2 (i2) 77 Where i0 is shared between both loops, but then the split into 78 different indexing vars. All loop ranges must iterate over 79 the same number of elements. 80 """ 81 82 def __init__( 83 self, 84 name: str, 85 var_list: List[sympy.Symbol], 86 var_ranges: Dict[sympy.Symbol, sympy.Expr], 87 numel: sympy.Expr, 88 prefix: str, 89 *, 90 kernel: SIMDKernel, 91 divisor=sympy.Integer(1), 92 length=sympy.Integer(1), 93 root: IterationRangesRoot, 94 ) -> None: 95 super().__init__() 96 self.name = name 97 self.var_list = var_list 98 self.var_ranges = var_ranges 99 self.numel = numel 100 self.prefix = prefix 101 self.divisor = divisor 102 self.length = length 103 self.kernel = kernel 104 self.root = root 105 106 def symbol(self): 107 return sympy_index_symbol(self.name) 108 109 110class IterationRangesRoot(IterationRanges): 111 def __init__( 112 self, 113 name: str, 114 numel: sympy.Expr, 115 # TODO: this is probably SymTy.INDEX and SymTy.RINDEX 116 prefix: str, 117 index: int, 118 kernel: SIMDKernel, 119 pid_cache=None, 120 *, 121 is_loop: bool, 122 tensor_dim: Optional[int], 123 grid_dim: Optional[int], 124 has_zdim: bool, 125 ) -> None: 126 if pid_cache is None: 127 pid_cache = {} 128 super().__init__( 129 name=name, 130 var_list=[], 131 var_ranges={}, 132 numel=numel, 133 prefix=prefix, 134 kernel=kernel, 135 root=self, 136 ) 137 self.index = index 138 # Store all the nodes in one flat list 139 self.nodes: Dict[sympy.Expr, IterationRangesEntry] = {} 140 # This is for re-ordering program ID in triton mm template 141 # pid_cache["tl.program_id(0)"] = pid_m 142 self.pid_cache: Dict[str, str] = pid_cache 143 144 # True if the dimension is implemented as a single program looping over 145 # the full dimension (currently only used for non-persistent reduction) 146 assert not is_loop or (prefix == "r" and grid_dim is None) 147 self.is_loop = is_loop 148 # Index of corresponding dimension on triton tensors 149 self.tensor_dim = tensor_dim 150 # Index of corresponding dimension in the triton grid 151 self.grid_dim = grid_dim 152 self.has_zdim = has_zdim 153 154 def __repr__(self) -> str: 155 return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" 156 157 def cache_clear(self): 158 for node in self.nodes.values(): 159 node.cache_clear() 160 161 def index_sym(self): 162 return sympy_index_symbol(f"{self.prefix}index") 163 164 def lookup(self, divisor, length): 165 """ 166 Lookup a given RangeTreeEntry, creating it if needed 167 """ 168 if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): 169 expr = FloorDiv(self.index_sym(), divisor) 170 else: 171 expr = ModularIndexing(self.index_sym(), divisor, length) 172 173 if expr not in self.nodes: 174 node = IterationRangesEntry( 175 f"{self.prefix}{next(V.kernel.iter_vars_count)}", 176 divisor, 177 length, 178 expr, 179 self, 180 ) 181 V.kernel.range_tree_nodes[node.symbol()] = node 182 self.var_list.append(node.symbol()) 183 self.var_ranges[node.symbol()] = length 184 self.nodes[expr] = node 185 return self.nodes[expr] 186 187 def construct_entries(self, lengths: List[sympy.Expr]): 188 divisor = sympy.Integer(1) 189 itervars = [] 190 for length in reversed(lengths): 191 itervars.append(self.lookup(divisor, length)) 192 divisor = divisor * length 193 return list(reversed(itervars)) 194 195 def construct(self, lengths: List[sympy.Expr]): 196 return [e.symbol() for e in self.construct_entries(lengths)] 197 198 def vars_and_sizes(self, index: sympy.Expr): 199 """Figure out vars from this tree used in index""" 200 nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] 201 nodes = [n for n in nodes if n and n.prefix == self.prefix] 202 nodes.sort( 203 key=lambda x: V.graph.sizevars.size_hint( 204 x.divisor, fallback=config.unbacked_symint_fallback 205 ) 206 ) 207 divisor = sympy.Integer(1) 208 index_vars = [] 209 sizes = [] 210 211 def add(node): 212 nonlocal divisor 213 index_vars.append(node.symbol()) 214 sizes.append(node.length) 215 divisor = divisor * node.length 216 217 for node in nodes: 218 if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): 219 # fill in unused index var 220 add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) 221 divisor = node.divisor 222 add(node) 223 if not V.graph.sizevars.statically_known_equals(self.numel, divisor): 224 # fill in unused index var 225 add(self.lookup(divisor, FloorDiv(self.numel, divisor))) 226 227 return list(reversed(index_vars)), list(reversed(sizes)) 228 229 230class IterationRangesEntry(IterationRanges): 231 def __init__( 232 self, 233 name: str, 234 divisor: sympy.Expr, 235 length: sympy.Expr, 236 expr: sympy.Expr, 237 parent: IterationRanges, 238 ) -> None: 239 super().__init__( 240 name=name, 241 numel=parent.numel / length, 242 var_list=parent.var_list, 243 var_ranges=parent.var_ranges, 244 prefix=parent.prefix, 245 divisor=divisor, 246 length=length, 247 kernel=parent.kernel, 248 root=parent.root, 249 ) 250 self.parent = parent 251 self.codegen = functools.lru_cache(None)(self._codegen) 252 self.expr = expr 253 254 def __repr__(self) -> str: 255 return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" 256 257 def set_name(self, name): 258 self.codegen = lambda: name # type: ignore[assignment] 259 self.codegen.cache_clear = lambda: None # type: ignore[method-assign] 260 self.name = name 261 262 def cache_clear(self): 263 self.codegen.cache_clear() 264 265 def _codegen(self): 266 V.kernel.codegen_iteration_ranges_entry(self) 267 return self.name 268 269 def precomputed_args(self): 270 # for dynamic shapes, find parts of indexing expressions that have to be precomputed 271 precomputed_args: List[sympy.Expr] = [] 272 if isinstance(self.expr, sympy.Symbol): 273 return precomputed_args 274 assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) 275 for arg in self.expr.args[1:]: 276 if not isinstance(arg, (sympy.Integer, sympy.Symbol)): 277 symbols = arg.free_symbols 278 if len(symbols) > 0 and all( 279 symbol_is_type(s, SymT.SIZE) for s in symbols 280 ): 281 precomputed_args.append(arg) 282 return precomputed_args 283 284 def __hash__(self): 285 return hash(self.name) 286 287 def __eq__(self, other): 288 return self.name == other.name 289 290 291def constant_repr(value): 292 if value == float("inf"): 293 return 'float("inf")' 294 elif value == float("-inf"): 295 return 'float("-inf")' 296 elif math.isnan(value): 297 return 'float("nan")' 298 return repr(value) 299 300 301class SIMDKernel(Kernel): 302 """ 303 Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. 304 """ 305 306 sexpr = pexpr 307 kexpr: Callable[[sympy.Expr], str] 308 allow_block_ptr = False 309 310 def __init__( 311 self, 312 *groups, 313 index_dtype: str, 314 mutations: Optional[OrderedSet[str]] = None, 315 pid_cache=None, 316 reduction_hint=ReductionHint.DEFAULT, 317 override_persistent_reduction=None, 318 ) -> None: 319 if pid_cache is None: 320 pid_cache = {} 321 super().__init__() 322 self.body = IndentedBuffer() 323 self.indexing_code = IndentedBuffer() 324 self.numels = [V.graph.sizevars.simplify(s) for s in groups] 325 self.mutations: OrderedSet[str] = ( 326 mutations if mutations is not None else OrderedSet() 327 ) 328 self.range_trees: List[IterationRangesRoot] = [] 329 self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {} 330 self.iter_vars_count = itertools.count() 331 self.inside_reduction = self.numels[-1] != 1 332 self.reduction_hint = reduction_hint 333 self.index_dtype: str = index_dtype 334 self.last_usage: OrderedSet[str] = OrderedSet() 335 self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list) 336 self.persistent_reduction: bool = ( 337 override_persistent_reduction 338 if override_persistent_reduction is not None 339 else self.should_use_persistent_reduction() 340 ) 341 self.no_x_dim = self.want_no_x_dim() 342 self.code_hash: Union[str, None] = None 343 344 # define this in a closure to make cache local to object 345 @functools.lru_cache(None) 346 def simplify_indexing(index: sympy.Expr): 347 index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) 348 for tree in self.range_trees: 349 index = self.combine_contiguous_dims(index, tree) 350 351 return self.combine_modular_indexing_pairs(index) 352 353 self.simplify_indexing = simplify_indexing 354 self.initialize_range_tree(pid_cache) 355 356 def want_no_x_dim(self): 357 return False 358 359 def initialize_range_tree(self, pid_cache): 360 no_r_dim = not self.inside_reduction or self.numels[-1] == 1 361 362 prefixes = "zyxr" 363 active_prefixes = prefixes[-len(self.numels) :] 364 365 grid_dims = "xyz" 366 if self.no_x_dim: 367 tensor_dims = "r" 368 elif no_r_dim: 369 tensor_dims = "xyz" 370 else: 371 tensor_dims = "xyzr" 372 373 tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes) 374 375 for i, prefix in enumerate(active_prefixes): 376 is_reduction = prefix == "r" 377 tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None 378 grid_dim = None if is_reduction else grid_dims.find(prefix) 379 index = i if grid_dim is None else grid_dim 380 self.range_trees.append( 381 IterationRangesRoot( 382 f"{prefix}index", 383 self.numels[i], 384 prefix, 385 index, 386 self, 387 pid_cache=pid_cache, 388 is_loop=is_reduction and not self.persistent_reduction, 389 tensor_dim=tensor_dim, 390 grid_dim=grid_dim, 391 has_zdim="z" in active_prefixes, 392 ) 393 ) 394 395 def finalize_indexing(self, indices: Sequence[sympy.Expr]): 396 """ 397 Hook called right before codegen with every index that will be 398 used in the fused kernel. 399 """ 400 401 def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): 402 prior = self.inside_reduction 403 self.inside_reduction = False 404 try: 405 return self.store(name, index, value) 406 finally: 407 self.inside_reduction = prior 408 409 def should_use_persistent_reduction(self) -> bool: 410 return False # defined in subclass 411 412 def var_ranges(self): 413 return dict( 414 itertools.chain.from_iterable( 415 tree.var_ranges.items() for tree in self.range_trees 416 ) 417 ) 418 419 def triton_tensor_ndim(self): 420 return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) 421 422 def indexing_size_str(self, i): 423 sizes = ["None"] * self.triton_tensor_ndim() 424 sizes[i] = ":" 425 return f"[{', '.join(sizes)}]" 426 427 def dense_size_list(self) -> List[str]: 428 sizes = ["1"] * self.triton_tensor_ndim() 429 for tree in self.range_trees: 430 if tree.tensor_dim is None: 431 continue 432 433 if tree.prefix != "r" or self.inside_reduction: 434 sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" 435 return sizes 436 437 def dense_size_str(self): 438 sizes = self.dense_size_list() 439 return f"[{', '.join(sizes)}]" 440 441 def combine_modular_indexing_pairs(self, index): 442 if not isinstance(index, ModularIndexing): 443 return index 444 x = index.args[0] 445 if (tree_node := self.range_tree_nodes.get(x)) is None: 446 return index 447 new_index = sympy_subs(index, {x: tree_node.expr}) 448 new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index) 449 # the index now contains xindex/etc, which is nonstandard, fix it up 450 return sympy_subs( 451 new_index, 452 { 453 tree_node.root.index_sym(): tree_node.root.lookup( 454 sympy.Integer(1), tree_node.root.numel 455 ).symbol() 456 }, 457 ) 458 459 def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): 460 if expand_res := V.graph.sizevars.expand_floor_div(index): 461 new_index, denominator = expand_res # type: ignore[misc] 462 return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) 463 else: 464 return self._combine_contiguous_dims(index, tree) 465 466 def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): 467 """ 468 More aggressive simplification to merge contiguous dims 469 """ 470 if isinstance(index, (sympy.Integer, sympy.Symbol)): 471 return index 472 index_vars, sizes = tree.vars_and_sizes(index) 473 if len(sizes) <= 1: 474 return index 475 new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( 476 index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) 477 ) 478 if new_sizes == sizes: 479 return index 480 new_index_vars = tree.construct(new_sizes) 481 new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) 482 return new_index 483 484 def set_last_usage(self, nodes): 485 if not self.inside_reduction or self.persistent_reduction: 486 return 487 self.last_usage = OrderedSet( 488 itertools.chain.from_iterable( 489 n.last_usage for n in nodes if n is not EnableReduction 490 ) 491 ) 492 493 def disable_reduction(self): 494 should_flush = self.range_trees[-1].is_loop 495 496 @contextlib.contextmanager 497 def ctx(): 498 if self.numels[-1] == 1: 499 assert not self.inside_reduction 500 yield 501 return 502 if should_flush: 503 # calling codegen_body() will flush all the pending buffers 504 # and write out a reduction loop 505 self.codegen_body() 506 self.inside_reduction = False 507 try: 508 yield 509 if should_flush: 510 # flush out any code before opening the next loop 511 self.codegen_body() 512 finally: 513 self.inside_reduction = True 514 515 return ctx() 516 517 def set_ranges(self, *lengths): 518 assert len(lengths) == len(self.range_trees) 519 return [ 520 ranges.construct(length) 521 for length, ranges in zip(lengths, self.range_trees) 522 ] 523 524 @staticmethod 525 def _split_iteration_ranges( 526 groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] 527 ): 528 sv = V.graph.sizevars 529 new_ranges: List[List[sympy.Expr]] = [[] for _ in groups] 530 remaining = [sv.simplify(g) for g in groups] 531 var_count = itertools.count() 532 533 def add_range(i, expr): 534 expr = sv.simplify(expr) 535 if not sv.statically_known_multiple_of(remaining[i], expr): 536 raise CantSplit 537 # guard on the last item out 538 remaining[i] = FloorDiv(remaining[i], expr) 539 new_ranges[i].append(expr) 540 return next(var_count) 541 542 def make_combined(size, idx1, idx2): 543 def getter(flat_vars): 544 return size * flat_vars[idx1] + flat_vars[idx2] 545 546 return getter 547 548 return_getters_groups = [] 549 current_group = 0 550 for length_group in lengths: 551 return_getters = [] 552 for size in length_group: 553 if sv.statically_known_equals(size, 1): # type: ignore[arg-type] 554 return_getters.append(lambda _: sympy.Integer(0)) 555 continue 556 557 while current_group < len(remaining) and sv.statically_known_equals( 558 remaining[current_group], 1 # type: ignore[arg-type] 559 ): 560 # scroll to next group with remaining elements 561 current_group += 1 562 563 if current_group + 1 < len(remaining) and sv.statically_known_gt( 564 size, remaining[current_group] 565 ): 566 # need to break size in two 567 if not sv.statically_known_multiple_of( 568 size, remaining[current_group] 569 ): 570 raise CantSplit 571 size1 = remaining[current_group] 572 size2 = FloorDiv(size, remaining[current_group]) 573 return_getters.append( 574 make_combined( 575 size2, 576 add_range(current_group, size1), 577 add_range(current_group + 1, size2), 578 ) 579 ) 580 else: 581 return_getters.append( 582 operator.itemgetter(add_range(current_group, size)) 583 ) 584 return_getters_groups.append(return_getters) 585 586 assert all( 587 V.graph.sizevars.size_hint(s) == 1 for s in remaining 588 ), f"failed to set ranges {remaining} {lengths}" 589 590 return new_ranges, return_getters_groups 591 592 @classmethod 593 def is_compatible( 594 cls, groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] 595 ): 596 try: 597 cls._split_iteration_ranges(groups, lengths) 598 return True 599 except CantSplit: 600 return False 601 602 def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]): 603 """ 604 We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). 605 606 To do this we need to split up the iteration space of i0 into something like: 607 for i1 in s0: 608 for i2 in s1: 609 i0 = i1*s1 + i2 610 .... 611 612 This function matches and resplits lengths to the groups of 613 this kernel to enable tiled + non-tiled fusions. 614 """ 615 groups = [rt.numel for rt in self.range_trees] 616 if not self.inside_reduction: 617 groups[-1] = sympy.Integer(1) 618 619 if len(lengths) == len(self.range_trees) and all( 620 V.graph.sizevars.simplify(sympy_product(x) - g) == 0 621 for x, g in zip(lengths, groups) 622 ): 623 return self.set_ranges(*lengths) 624 625 new_ranges, return_getters_groups = self._split_iteration_ranges( 626 groups, lengths 627 ) 628 itervars = list(itertools.chain.from_iterable(self.set_ranges(*new_ranges))) 629 return [[fn(itervars) for fn in fns] for fns in return_getters_groups] 630 631 def is_indirect_indexing(self, index: sympy.Expr): 632 # tmpX means indirect indexing 633 return free_symbol_is_type(index, SymT.TMP) 634 635 def is_broadcasted(self, index: sympy.Expr): 636 # Note. This may not be correct when there is indirect indexing 637 if self.is_indirect_indexing(index): 638 return False 639 640 index_numels = [1] * len(self.numels) 641 for symbol in index.free_symbols: 642 if symbol not in self.range_tree_nodes: 643 # Non-iterated variables, e.g. strides 644 continue 645 entry = self.range_tree_nodes[symbol] # type: ignore[index] 646 assert isinstance(entry.parent, IterationRangesRoot) 647 index_numels[entry.parent.index] *= entry.length 648 649 # If the index variables only iterate over a subset of the kernel 650 # numels, then it must be broadcasted. 651 simplify = V.graph.sizevars.simplify 652 return any( 653 simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] 654 for idx_range, iter_range in zip(index_numels, self.numels) 655 ) 656 657 def index_to_str(self, index: sympy.Expr) -> str: 658 """ 659 Convert an index expr to a string that can be used in output code. 660 e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel. 661 662 Index expressions often need to be passed in as arguments to the triton kernel. 663 Rename_indexing and codegen_indexing keep track of the needed indices and add 664 new parameters to the function signature. 665 """ 666 if isinstance(index, list): 667 return f"[{', '.join(map(self.index_to_str, index))}]" 668 return self.kexpr(self.rename_indexing(index)) # type: ignore[call-arg] 669 670 def prepare_indexing( 671 self, 672 index: sympy.Expr, 673 ): 674 index = self.simplify_indexing(index) 675 index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) 676 # if simple replacements didn't get rid of floor/ceil, try full subs 677 if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): 678 index = index.subs(V.graph.sizevars.precomputed_replacements) 679 # last resort, if no range vars are in the expr, hoist it 680 # TODO instead of trying to blindly find complicated exprs, we should hoist the 681 # inputs/outputs sizes and strides, but at the time indexing is generated 682 # kernel inputs and outputs are not set yet, we'd need a deeper refactor 683 # to do it this way 684 685 if len(index.atoms(sympy.ceiling)): 686 for a in index.atoms(sympy.ceiling): 687 # for nested exprs, atoms yields top level first (?) 688 # so if everything goes fine, lower level replacements will come up empty 689 symbols = a.free_symbols 690 if len(symbols) > 0 and all( 691 symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) 692 for s in symbols 693 ): 694 replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} 695 index = sympy_subs(index, replacements) 696 697 simp_index = self.simplify_indexing(index) 698 699 # Now that we are done simplifying we can unwrap Identity so that downstream handling 700 # for its contained expression will work. previously, tl.full wrapping of sympy.Integer 701 # would not occur 702 simp_index = ( 703 simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] 704 ) 705 706 return self.codegen_indexing(simp_index) 707 708 def active_range_trees(self, reorder=False): 709 trees = [ 710 t for t in self.range_trees if t.prefix != "r" or self.inside_reduction 711 ] 712 if reorder and len(trees) > 1: 713 count = sum(t.prefix in "xyz" for t in trees) 714 assert "".join(t.prefix for t in trees[:count]) == "zyx"[-count:], [ 715 t.prefix for t in trees[:count] 716 ] 717 trees[:count] = reversed(trees[:count]) 718 return trees 719 720 def codegen_indexing(self, expr: sympy.Expr): 721 expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) 722 for sym in sorted(expr.free_symbols, key=str): 723 if sym in self.range_tree_nodes: 724 # if indexing expression is complicated, we precompute it on the host side 725 # and send the result as a kernel argument 726 replacements = {} 727 for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] 728 replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) 729 if len(replacements) > 0: 730 self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] 731 self.range_tree_nodes[sym].expr, replacements # type: ignore[index] 732 ) 733 self.range_tree_nodes[sym].codegen() # type: ignore[index] 734 return expr 735 736 def codegen_nan_check(self) -> None: 737 raise NotImplementedError("NYI: codegen_nan_check") 738 739 def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: 740 raise NotImplementedError("NYI: call_kernel") 741 742 @contextlib.contextmanager 743 def mask_loads(self, mask, value): 744 """Context manager to add an additional mask to tl.load/store""" 745 prior = self._load_mask 746 prior_val = self._load_other 747 if prior: 748 mask = ops.logical_and(mask, prior) 749 750 mask = OpsWrapper._unwrap(mask) 751 self._load_mask = mask 752 self._load_other = value 753 try: 754 # TODO(jansel): do we need a reshape here? 755 yield mask 756 finally: 757 self._load_mask = prior 758 self._load_other = prior_val 759 760 def get_strides_of_load(self, index: sympy.Expr): 761 """ 762 This gets the stride of the index for each of the tiling variables 763 (technically, it does it at index 0) 764 765 For example, if 766 xindex = x0 + 512*x1 + 1024*r0 767 x0 = (xindex//512) 768 x1 = (xindex % 512) 769 r0 = rindex // 1024 770 771 this function would return 772 {xindex: 512, rindex: 1024} 773 """ 774 index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} 775 index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] 776 strides = {} 777 for range_tree in self.range_trees: 778 s = sympy_index_symbol(range_tree.name) 779 strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( 780 index_in_tile_vars, {s: 0} 781 ) 782 return strides 783 784 @staticmethod 785 def _map_tuple_or_scalar(fn, value): 786 if isinstance(value, tuple): 787 return tuple(map(fn, value)) 788 return fn(value) 789 790 def estimate_kernel_num_bytes(self): 791 """ 792 Try the best to estimate the total size (in bytes) of the 793 kernel's inputs and outputs, which is used for estimating the memory 794 throughput of this kernel. This information is used for checking how 795 far we are from the peak memory bandwidth. It's important that 796 we want to avoid overestimating the sizes of the inputs and outputs, 797 because it can wrongfully give us a very large memory traffic value, 798 which may be even larger than the theoretical bandwidth and thus 799 become very misleading. This is particularly problematic for cases 800 where we slice some inputs. In those cases, we should only count 801 the size of the "slices" instead of the original inputs, because 802 only the slices contribute to the real memory traffic. 803 """ 804 nbytes = [] 805 ninplace_args = len(unique(self.args.inplace_buffers.values())) 806 _, call_args, _, _ = self.args.python_argdefs() 807 808 # For pointwise and reduction kernels, this is the upper-bound numels 809 # for the output buffer. 810 # FIXME: This is not exactly right for cases like below: 811 # def foo(tensor0, tensor1): 812 # x0 = narrow(tensor0) 813 # return cat(x0, tensor1) 814 # For this example, we will end up overestimate the size for the 815 # slice s0. Potentially, we could have precise inputs information 816 # if we maintained the original inputs of the Pointwise kernel created 817 # for the "cat". However, I think it might be a bit overwhelming that 818 # we add such complexity only for handling some particular cases for 819 # benchmarking. 820 out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels)) 821 for i, arg in enumerate(call_args): 822 # "buf" may be narrowed. In this case, the number of memory accesses 823 # should be estimated based on the reinterpreted layout. 824 # On the other hand, buf may be broadcasted. In this case, 825 # counting the size of the underline storage would give us 826 # a better estimation in terms of memory accesses. 827 if arg not in self.buf_accesses: 828 nbytes.append(0) 829 continue 830 arg_numel = V.graph.get_numel(arg) 831 buf_size = V.graph.sizevars.size_hint(arg_numel) 832 if buf_size > out_numel: 833 # This arg points to a buf that has been sliced. 834 # We need to count each individual slice to have 835 # a better estimation. 836 indices: OrderedSet[Any] = OrderedSet() 837 no_index_dep_count = 0 838 for dep in self.buf_accesses[arg]: 839 if isinstance(dep, (StarDep, WeakDep)): 840 indices.add(f"no_index_dep_{no_index_dep_count}") 841 no_index_dep_count += 1 842 else: 843 indices.add(dep.index) 844 numel = len(indices) * out_numel 845 else: 846 numel = buf_size 847 dtype = V.graph.get_dtype(arg) 848 dtype_size = get_dtype_size(dtype) 849 nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) 850 return sum(nbytes) 851 852 def warn_mix_layout(self, kernel_name): 853 """ 854 Print message if the kernel have mixed layout inputs. 855 Only care about 4D tensor for now. 856 """ 857 if ( 858 len(self.args.input_buffers) == 1 859 and len(self.args.output_buffers) == 1 860 and len(self.args.inplace_buffers) == 0 861 ): 862 # even if input buffer and output buffer have different layout, 863 # this can be a layout conversion kernel. No need to warn for 864 # the mix layouts. 865 return 866 867 argdefs, call_args, signature, _ = self.args.python_argdefs() 868 uniform_stride_order = None 869 for arg_name in call_args: 870 buf = V.graph.try_get_buffer(arg_name) 871 if buf and len(buf.layout.size) == 4: 872 # ignore the tensor if only 1 dimension is non-zero 873 if len([x for x in buf.layout.size if x == 1]) == 3: 874 continue 875 stride_order = ir.get_stride_order(buf.layout.stride) 876 if uniform_stride_order is None: 877 uniform_stride_order = stride_order 878 elif uniform_stride_order != stride_order: 879 msg = yellow_text( 880 f"Expected stride order {uniform_stride_order}, but found stride order" 881 + f" {stride_order} for kernel {kernel_name}" 882 ) 883 log.warning(msg) 884 885 stride_order_list = [ 886 ir.get_stride_order(V.graph.get_buffer(name).layout.stride) 887 if V.graph.try_get_buffer(name) 888 else None 889 for name in call_args 890 ] 891 size_list = [ 892 V.graph.get_buffer(name).layout.size 893 if V.graph.try_get_buffer(name) 894 else None 895 for name in call_args 896 ] 897 source_list = [ 898 "GraphInput" 899 if name in V.graph.graph_inputs 900 else "IntermediateBuffer" 901 if name in V.graph.name_to_buffer 902 else None 903 for name in call_args 904 ] 905 906 msg = yellow_text( 907 f" param names {argdefs}\n buf names {call_args}\n strides {stride_order_list}" 908 + f"\n sizes {size_list}\n sources {source_list}\n" 909 ) 910 log.warning(msg) 911 return 912 msg = green_text( 913 f"All the inputs for the triton kernel {kernel_name} have uniform layout" 914 ) 915 log.warning(msg) 916 917 def welford_reduce_fallback(self, dtype, value): 918 sum_ = ops.reduction(dtype, dtype, "sum", value) 919 self.inside_reduction = False 920 rnumel = ops.index_expr(self.numels[-1], dtype) 921 mean = ops.truediv(sum_, rnumel) 922 923 self.inside_reduction = True 924 dx = ops.sub(value, mean) 925 dx2 = ops.mul(dx, dx) 926 m2 = ops.reduction(dtype, dtype, "sum", dx2) 927 return OpsWrapper._unwrap((mean, m2, rnumel)) 928 929 def codegen_kernel(self): 930 raise NotImplementedError 931 932 def codegen_body(self): 933 pass 934 935 def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): 936 pass 937 938 939class SIMDScheduling(BaseScheduling): 940 kernel_type = SIMDKernel # override in subclass 941 int32_type = "torch.int32" 942 int64_type = "torch.int64" 943 944 def __init__(self, scheduler) -> None: 945 super().__init__() 946 self.scheduler = scheduler 947 948 def group_fn(self, sizes): 949 return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) 950 951 def can_fuse(self, node1, node2): 952 """ 953 Hook called by Scheduler to determine if the Triton backend 954 can fuse node1 and node2. These nodes might already be 955 FusedSchedulerNodes. 956 """ 957 if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( 958 node2, scheduler.ForeachKernelSchedulerNode 959 ): 960 return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) 961 962 _, (numel1, rnumel1) = node1.group 963 _, (numel2, rnumel2) = node2.group 964 why = WhyNoFuse(node1, node2) 965 966 if node1.is_split_scan() and not node2.is_split_scan(): 967 if node2.is_reduction(): 968 why("Split scan cannot fuse with reductions") 969 elif node2.is_split_scan() and not node1.is_split_scan(): 970 if node1.is_reduction(): 971 why("Split scan cannot fuse with reductions") 972 973 if node1.is_reduction() and node2.is_reduction(): 974 reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 975 if not reduction_can_fuse: 976 why( 977 "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", 978 numel1, 979 numel2, 980 rnumel1, 981 rnumel2, 982 ) 983 return reduction_can_fuse 984 985 if not node1.is_reduction() and not node2.is_reduction(): 986 if not (numel1 == numel2 and rnumel1 == rnumel2): 987 why( 988 "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", 989 numel1, 990 numel2, 991 rnumel1, 992 rnumel2, 993 ) 994 return False 995 996 if node1.is_template(): 997 # Only allow fusion for TritonTemplates for now. 998 # Fusion for CUDATemplates are not supported. 999 is_triton_template = isinstance(node1.node, TritonTemplateBuffer) 1000 if not is_triton_template: 1001 why("node1 is not TritonTemplateBuffer") 1002 return is_triton_template 1003 1004 # check for a bad combined tiling 1005 tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) 1006 tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) 1007 tiling3 = self.select_tiling( 1008 node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 1009 ) 1010 if config.triton.tiling_prevents_pointwise_fusion: 1011 cond = True 1012 if len(tiling1) > 2: 1013 if len(tiling2) > 2: 1014 cond = tiling1 == tiling2 == tiling3 1015 else: 1016 cond = tiling1 == tiling3 1017 elif len(tiling2) > 2: 1018 cond = tiling2 == tiling3 1019 if not cond: 1020 why( 1021 "tiling mismatch (%s, %s, %s)", 1022 tiling1, 1023 tiling2, 1024 tiling3, 1025 ) 1026 return False 1027 1028 return True 1029 1030 if not node1.is_reduction() and node2.is_reduction(): 1031 assert rnumel1 == 1 and rnumel2 != 1 1032 if numel1 == numel2 * rnumel2: 1033 if not all( 1034 SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) 1035 for n in node1.get_nodes() 1036 ): 1037 why("nodes numel/rnumel incompatibility") 1038 return False 1039 if ( 1040 config.triton.tiling_prevents_reduction_fusion 1041 and not node1.is_template() 1042 ): 1043 is_reduction_tiling_valid = self.select_tiling( 1044 node1.get_nodes(), numel1 1045 ) in ( 1046 (numel1, 1), 1047 (numel2, rnumel2, 1), 1048 ) 1049 if not is_reduction_tiling_valid: 1050 why("invalid tiling for reduction") 1051 return is_reduction_tiling_valid 1052 return True 1053 1054 if numel1 != numel2: 1055 why("nodes numel incompatibility") 1056 return numel1 == numel2 1057 1058 assert node1.is_reduction() and not node2.is_reduction() 1059 # swap args to hit the case above 1060 return self.can_fuse_horizontal(node2, node1) 1061 1062 can_fuse_vertical = can_fuse 1063 can_fuse_horizontal = can_fuse 1064 1065 def generate_node_schedule(self, nodes, numel, rnumel): 1066 node_schedule: List[Any] = [] 1067 done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() 1068 # Writes with a reduced shape, meaning they are only present once the 1069 # reduction loop has ended 1070 not_ready_yet_nodes: OrderedSet[str] = OrderedSet() 1071 1072 def fits_in_main_body(n): 1073 _, (node_numel, node_rnumel) = n.group 1074 return (node_numel == numel and node_rnumel == rnumel) or ( 1075 node_numel == numel * rnumel and node_rnumel == 1 1076 ) 1077 1078 def fits_outside_reduction(n): 1079 _, (node_numel, node_rnumel) = n.group 1080 return node_numel == numel and node_rnumel == 1 and rnumel != 1 1081 1082 def schedule_node_in_loop(n): 1083 done.add(n) 1084 node_schedule.append(n) 1085 # A scan is modelled as a reduction in the scheduler but has a 1086 # full sized output that can be used inside the loop body 1087 if ( 1088 n.is_reduction() 1089 and isinstance(n, scheduler.SchedulerNode) 1090 and isinstance(n.node, ir.ComputedBuffer) 1091 and not isinstance(n.node.data, ir.Scan) 1092 ): 1093 not_ready_yet_nodes.add(n.get_name()) 1094 1095 @contextlib.contextmanager 1096 def end_current_reduction_loop(): 1097 if node_schedule and node_schedule[-1] is EnableReduction: 1098 node_schedule.pop() 1099 else: 1100 node_schedule.append(DisableReduction) 1101 yield 1102 node_schedule.append(EnableReduction) 1103 not_ready_yet_nodes.clear() 1104 1105 def requires_closing_previous_reduction(node, node_schedule): 1106 if rnumel == 1: 1107 return False 1108 if not not_ready_yet_nodes & node.ancestors: 1109 return False 1110 assert node_schedule and not isinstance( 1111 node_schedule[-1], (EnableReduction, DisableReduction) 1112 ) 1113 return bool(not_ready_yet_nodes) 1114 1115 for index, node in enumerate(nodes): 1116 if node in done: 1117 continue 1118 done.add(node) 1119 1120 if fits_in_main_body(node): 1121 if requires_closing_previous_reduction(node, node_schedule): 1122 with end_current_reduction_loop(): 1123 pass # need to start a new reduction loop 1124 1125 schedule_node_in_loop(node) 1126 elif fits_outside_reduction(node): 1127 with end_current_reduction_loop(): 1128 node_schedule.append(node) 1129 else: 1130 raise NotImplementedError( 1131 f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" 1132 ) 1133 1134 return node_schedule 1135 1136 def codegen_node( 1137 self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] 1138 ): 1139 """ 1140 Given a set of pre-fused nodes, generate a Triton kernel. 1141 """ 1142 1143 nodes: List[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] 1144 1145 _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group 1146 1147 node_schedule = self.generate_node_schedule(nodes, numel, rnumel) 1148 buf_accesses = collections.defaultdict(list) 1149 for node in nodes: 1150 for access in node.read_writes.reads | node.read_writes.writes: 1151 buf_accesses[access.name].append(access) 1152 1153 schedule_log.debug("Schedule:\n %s", node_schedule) 1154 1155 return self.codegen_node_schedule(node_schedule, buf_accesses, numel, rnumel) 1156 1157 @staticmethod 1158 def reduction_hint(node): 1159 assert node.is_reduction() 1160 if all( 1161 dep.is_contiguous() 1162 for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) 1163 ): 1164 return ReductionHint.INNER 1165 else: 1166 return node.node.data.reduction_hint 1167 1168 @staticmethod 1169 def can_use_32bit_indexing( 1170 numel: sympy.Expr, buffers: Iterable[Union[ir.Buffer, ir.TensorBox]] 1171 ) -> bool: 1172 int_max = torch.iinfo(torch.int32).max 1173 size_hint = V.graph.sizevars.size_hint 1174 has_hint = V.graph.sizevars.shape_env.has_hint 1175 1176 def within_32bit(e): 1177 # Allow for unhinted e as long as we can still statically prove 1178 # (e.g., via ValueRanges) that it is still in bounds 1179 if V.graph.sizevars.is_expr_static_and_true(e <= int_max): 1180 return True 1181 # Otherwise, the hint MUST exist and be in range 1182 return has_hint(e) and size_hint(e) <= int_max 1183 1184 if not within_32bit(numel): 1185 return False 1186 1187 # Any use of a MultiOutputLayout will create a buffer with a 1188 # Layout whose sizes are accounted for 1189 buf_sizes = [ 1190 buf.get_layout().storage_size() 1191 for buf in buffers 1192 if not isinstance(buf.get_layout(), ir.MultiOutputLayout) 1193 ] 1194 1195 if not all(within_32bit(size) for size in buf_sizes): 1196 return False 1197 1198 # Only install guards for 32-bit indexing as there is no correctness 1199 # issue with using 64-bit for everything 1200 V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] 1201 for size in buf_sizes: 1202 V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] 1203 return True 1204 1205 @classmethod 1206 def select_index_dtype(cls, node_schedule, numel, reduction_numel): 1207 # Gather all used buffer names 1208 buffer_names: OrderedSet[str] = OrderedSet() 1209 for node in node_schedule: 1210 if not isinstance(node, scheduler.BaseSchedulerNode): 1211 continue 1212 1213 buffer_names.update(node.get_buffer_names()) 1214 buffer_names.update(node.used_buffer_names()) 1215 1216 # Get buffers objects 1217 1218 def _get_buffer(name: str) -> Union[ir.Buffer, ir.TensorBox]: 1219 buf = V.graph.get_buffer(name) 1220 if buf is None: 1221 raise RuntimeError(f"Failed to find buffer matching name {name}") 1222 return buf 1223 1224 buffers = [V.graph.get_buffer(name) for name in buffer_names] 1225 1226 # In theory we can separately check xnumel and rnumel are <= int_max 1227 # but some indexers do use the full linear index so we need to be 1228 # conservative here. 1229 total_numel = numel * reduction_numel 1230 1231 if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): 1232 return cls.int32_type 1233 return cls.int64_type 1234 1235 def has_non_contiguous_pw_in_reduction_kernel(self, node_schedule, numel, rnumel): 1236 pointwise_nodes = list( 1237 filter( 1238 lambda n: n not in (EnableReduction, DisableReduction) 1239 and not n.is_reduction() 1240 and n.group[1][0] == numel * rnumel, 1241 node_schedule, 1242 ) 1243 ) 1244 for node in pointwise_nodes: 1245 # An index can be an integer when loading a random seed. 1246 if not all( 1247 not isinstance(dep, MemoryDep) 1248 or dep.is_contiguous() 1249 or isinstance(dep.index, (sympy.Integer, int)) 1250 or dep.stride1_for_last_dim() 1251 for dep in itertools.chain( 1252 node.read_writes.reads, node.read_writes.writes 1253 ) 1254 ): 1255 return True 1256 return False 1257 1258 def get_kernel_args(self, node_schedule, numel, reduction_numel): 1259 reductions = list( 1260 filter( 1261 lambda n: n not in (EnableReduction, DisableReduction) 1262 and n.is_reduction(), 1263 node_schedule, 1264 ) 1265 ) 1266 if len(reductions) > 0: 1267 hints = [self.reduction_hint(n) for n in reductions] 1268 if hints.count(hints[0]) == len(hints): 1269 reduction_hint_val = hints[0] 1270 else: 1271 reduction_hint_val = ReductionHint.DEFAULT 1272 1273 if ( 1274 reduction_hint_val == ReductionHint.INNER 1275 and self.has_non_contiguous_pw_in_reduction_kernel( 1276 node_schedule, numel, reduction_numel 1277 ) 1278 ): 1279 reduction_hint_val = ReductionHint.DEFAULT 1280 else: 1281 reduction_hint_val = ReductionHint.DEFAULT 1282 1283 mutations: OrderedSet[str] = OrderedSet() 1284 for node in node_schedule: 1285 if node in (DisableReduction, EnableReduction): 1286 continue 1287 1288 for buf in node.get_outputs(): 1289 mutations.update(buf.get_mutations()) 1290 1291 index_dtype = self.select_index_dtype(node_schedule, numel, reduction_numel) 1292 1293 return reduction_hint_val, mutations, index_dtype 1294 1295 def codegen_node_schedule( 1296 self, node_schedule, buf_accesses, numel, reduction_numel 1297 ): 1298 from torch._inductor.codegen.triton_split_scan import TritonSplitScanKernel 1299 1300 tiled_groups = self.select_tiling(node_schedule, numel, reduction_numel) 1301 ( 1302 reduction_hint_val, 1303 mutations, 1304 index_dtype, 1305 ) = self.get_kernel_args(node_schedule, numel, reduction_numel) 1306 1307 is_split_scan = any( 1308 isinstance(node, BaseSchedulerNode) and node.is_split_scan() 1309 for node in node_schedule 1310 ) 1311 kernel_type: type = self.kernel_type 1312 if is_split_scan and issubclass(TritonSplitScanKernel, kernel_type): 1313 kernel_type = TritonSplitScanKernel 1314 1315 kernel_args = tiled_groups 1316 kernel_kwargs = dict( 1317 reduction_hint=reduction_hint_val, 1318 mutations=mutations, 1319 index_dtype=index_dtype, 1320 ) 1321 1322 def _node_has_sort(node): 1323 if node in (EnableReduction, DisableReduction): 1324 return False 1325 1326 sort_nodes = node._body.root_block.graph.find_nodes( 1327 op="call_method", target="sort" 1328 ) 1329 return bool(sort_nodes) 1330 1331 # ops.sort only works with persistent reduction, and is not bandwidth bound anyway 1332 # so taking the hit of non-coalesced loads is okay 1333 has_sort = any(_node_has_sort(node) for node in node_schedule) 1334 if has_sort: 1335 kernel_kwargs["override_persistent_reduction"] = True 1336 1337 kernel = kernel_type( 1338 *kernel_args, 1339 **kernel_kwargs, 1340 ) 1341 kernel.buf_accesses = buf_accesses 1342 1343 kernel2: Optional[SIMDKernel] = None 1344 if kernel.persistent_reduction and config.triton.multi_kernel and not has_sort: 1345 kernel2 = self.kernel_type( 1346 *kernel_args, 1347 **kernel_kwargs, 1348 override_persistent_reduction=False, 1349 ) 1350 self.codegen_node_schedule_with_kernel(node_schedule, kernel2) 1351 with V.set_kernel_handler(kernel2): 1352 src_code2 = kernel2.codegen_kernel() 1353 kernel_name2 = self.define_kernel(src_code2, node_schedule, kernel) 1354 kernel2.kernel_name = kernel_name2 1355 kernel2.code_hash = code_hash(src_code2) 1356 1357 # Keep buffers needed by the non-persistent reduction so both 1358 # kernels have the same arguments 1359 kernel.must_keep_buffers = set(kernel2.must_keep_buffers) 1360 1361 self.codegen_node_schedule_with_kernel(node_schedule, kernel) 1362 1363 with V.set_kernel_handler(kernel): 1364 src_code = kernel.codegen_kernel() 1365 1366 kernel_name = self.define_kernel(src_code, node_schedule, kernel) 1367 log.debug("Generating kernel code with kernel_name: %s", kernel_name) 1368 kernel.kernel_name = kernel_name 1369 kernel.code_hash = code_hash(src_code) 1370 1371 final_kernel = MultiKernel([kernel, kernel2]) if kernel2 is not None else kernel 1372 1373 with V.set_kernel_handler(final_kernel): 1374 for node in node_schedule: 1375 if node not in (EnableReduction, DisableReduction): 1376 node.mark_run() 1377 1378 self.codegen_comment(node_schedule) 1379 final_kernel.call_kernel(final_kernel.kernel_name) 1380 1381 if config.nan_asserts: 1382 final_kernel.codegen_nan_check() 1383 if config.warn_mix_layout: 1384 final_kernel.warn_mix_layout(kernel_name) 1385 1386 V.graph.removed_buffers |= final_kernel.removed_buffers 1387 V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove 1388 1389 if ( 1390 V.graph.wrapper_code.supports_intermediate_hooks 1391 and config.generate_intermediate_hooks 1392 ): 1393 # Not every node in the schedule will actually be live on output; 1394 # we can't check dead buffers. 1395 live_outs = kernel.args.live_output_buffers() 1396 for node in node_schedule: 1397 if not isinstance(node, scheduler.BaseSchedulerNode): 1398 continue 1399 name = node.get_name() 1400 if name not in live_outs: 1401 continue 1402 assert node.node is not None 1403 origin_node = node.node.get_origin_node() 1404 if origin_node is not None: 1405 counters["inductor"]["intermediate_hooks"] += 1 1406 V.graph.wrapper_code.writeline( 1407 f"run_intermediate_hooks({origin_node.name!r}, {name})" 1408 ) 1409 1410 self.scheduler.free_buffers() 1411 1412 def codegen_node_schedule_with_kernel(self, node_schedule, kernel): 1413 def current_reduction_nodes(nodes): 1414 return itertools.takewhile(lambda n: n is not DisableReduction, nodes) 1415 1416 with kernel: 1417 stack = contextlib.ExitStack() 1418 kernel.set_last_usage(current_reduction_nodes(node_schedule)) 1419 all_indexing = {} 1420 1421 # First pass to collect indexing and decide inplace updates 1422 for node in node_schedule: 1423 if node is DisableReduction: 1424 stack.enter_context(kernel.disable_reduction()) 1425 elif node is EnableReduction: 1426 stack.close() 1427 else: 1428 node.decide_inplace_update() 1429 index_vars = kernel.split_and_set_ranges(node.get_ranges()) 1430 all_indexing.update( 1431 dict.fromkeys( 1432 node._body.indexing_from_args(index_vars).values() 1433 ) 1434 ) 1435 1436 kernel.finalize_indexing(all_indexing.keys()) 1437 1438 # Second pass to do codegen 1439 for i, node in enumerate(node_schedule): 1440 if node is DisableReduction: 1441 stack.enter_context(kernel.disable_reduction()) 1442 elif node is EnableReduction: 1443 stack.close() 1444 kernel.set_last_usage(current_reduction_nodes(node_schedule[i:])) 1445 else: 1446 # TODO - use split ranges ? 1447 indexing_dtype_strength_reduction(node._body) 1448 index_vars = kernel.split_and_set_ranges(node.get_ranges()) 1449 node.codegen(index_vars) 1450 1451 def codegen_template( 1452 self, template_node, epilogue_nodes, only_gen_src_code=False 1453 ) -> Optional[str]: 1454 """ 1455 Codegen a triton template 1456 1457 If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper 1458 """ 1459 _, (numel, rnumel) = template_node.group 1460 assert rnumel == 1 1461 kernel, render = template_node.node.make_kernel_render(template_node.node) 1462 with kernel: 1463 if not only_gen_src_code: 1464 for node in [template_node, *epilogue_nodes]: 1465 node.mark_run() 1466 partial_code = render() 1467 with kernel.set_subgraph_body("<STORE_OUTPUT>"): 1468 for node in epilogue_nodes: 1469 node.codegen(kernel.split_and_set_ranges(node.get_ranges())) 1470 1471 if not isinstance(partial_code, str): 1472 partial_code.finalize_hook("<DEF_KERNEL>") 1473 partial_code.finalize_hook("<ARGDEFS>", strict=False) 1474 # finalize must be called after adding epilogue above 1475 with V.set_kernel_handler(kernel): 1476 # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. 1477 with kernel.set_subgraph_body("<STORE_OUTPUT>"): 1478 if isinstance(partial_code, str): 1479 src_code = partial_code 1480 else: 1481 partial_code.finalize_hook("<STORE_OUTPUT>") 1482 src_code = partial_code.code 1483 node_schedule = [template_node, *epilogue_nodes] 1484 1485 if config.benchmark_kernel: 1486 num_gb = kernel.estimate_kernel_num_bytes() / 1e9 1487 grid_args = V.graph.sizevars.size_hints(kernel.call_sizes) 1488 assert kernel.meta is not None, "meta is None" 1489 grid = kernel.grid_fn(*grid_args, kernel.meta) 1490 src_code = ( 1491 f"{kernel.imports_for_benchmark_kernel()}\n" 1492 f"{src_code}\n" 1493 f"{kernel.codegen_kernel_benchmark(num_gb, grid).getvalue()}" 1494 ) 1495 1496 if only_gen_src_code: 1497 return src_code 1498 1499 kernel_name = self.define_kernel(src_code, node_schedule, kernel) 1500 1501 self.codegen_comment(node_schedule) 1502 kernel.call_kernel(kernel_name, template_node.node) 1503 1504 V.graph.removed_buffers |= kernel.removed_buffers 1505 V.graph.inplaced_to_remove |= kernel.inplaced_to_remove 1506 self.scheduler.free_buffers() 1507 return None 1508 1509 def codegen_sync(self): 1510 V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) 1511 1512 def generate_combo_kernel_code( 1513 self, 1514 subkernel_nodes: List[BaseSchedulerNode], 1515 custom_part_algorithm: bool, 1516 enable_autotune: bool, 1517 mixed_sizes: bool, 1518 only_gen_src_code: bool = False, 1519 ) -> List[Tuple[str, Any, Any]]: 1520 from .triton_combo_kernel import ComboKernel 1521 1522 fused_node_lists = [node.get_nodes() for node in subkernel_nodes] 1523 subkernel_map, node_schedule_map = {}, {} 1524 for pn, nodes in zip(subkernel_nodes, fused_node_lists): 1525 _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group 1526 node_schedule = self.generate_node_schedule(nodes, numel, rnumel) 1527 tiled_groups = self.select_tiling(node_schedule, numel, rnumel) 1528 node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel 1529 ( 1530 reduction_hint_val, 1531 mutations, 1532 index_dtype, 1533 ) = self.get_kernel_args(node_schedule, numel, rnumel) 1534 subkernel_map[pn] = ComboKernel.create_triton_kernel( 1535 *tiled_groups, 1536 reduction_hint=reduction_hint_val, 1537 mutations=mutations, 1538 index_dtype=index_dtype, 1539 optimize_mask=not mixed_sizes, 1540 ) 1541 1542 partitions = ComboKernel.horizontal_partition( 1543 nodes=subkernel_nodes, 1544 triton_scheduling=self, 1545 custom_algorithm=custom_part_algorithm, 1546 kernel_map=subkernel_map, 1547 node_info_map=node_schedule_map, 1548 ) 1549 log.debug( 1550 "ComboKernels: %d nodes partitioned into %s groups", 1551 len(subkernel_nodes), 1552 [len(p) for p in partitions], 1553 ) 1554 kernel_code_list = [] 1555 for node_group in partitions: 1556 fused_node_lists = [node.get_nodes() for node in node_group] 1557 kernel = ComboKernel( 1558 enable_autotune=enable_autotune, 1559 mixed_sizes=mixed_sizes, 1560 ) 1561 1562 for pn, nodes in zip(node_group, fused_node_lists): 1563 if only_gen_src_code: 1564 # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. 1565 for n in nodes: 1566 n.last_usage = OrderedSet() 1567 self.codegen_node_schedule_with_kernel( 1568 node_schedule_map[pn][0], 1569 kernel.create_sub_kernel(subkernel_map[pn]), 1570 ) 1571 subkernel = subkernel_map[pn] 1572 node_schedule = node_schedule_map[pn][0] 1573 if not only_gen_src_code: 1574 with V.set_kernel_handler(subkernel): # type: ignore[call-arg] 1575 for node in node_schedule: 1576 if node not in (EnableReduction, DisableReduction): 1577 node.mark_run() 1578 V.graph.removed_buffers |= subkernel.removed_buffers 1579 V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove 1580 1581 src_code = kernel.codegen_kernel() 1582 kernel_code_list.append((src_code, kernel, node_group)) 1583 return kernel_code_list 1584 1585 def codegen_combo_kernel(self, combo_kernel_node): 1586 subkernel_nodes = combo_kernel_node.get_subkernel_nodes() 1587 custom_part_algorithm = combo_kernel_node.use_custom_partition_algo 1588 enable_autotune = combo_kernel_node.enable_autotune 1589 mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( 1590 config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm 1591 ) 1592 1593 kernel_code_list = self.generate_combo_kernel_code( 1594 subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes 1595 ) 1596 1597 for src_code, kernel, _ in kernel_code_list: 1598 kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) 1599 self.codegen_comment([combo_kernel_node]) 1600 log.debug("ComboKernels: generated kernel %s.", kernel_name) 1601 kernel.call_kernel(V.graph.wrapper_code, kernel_name) 1602 1603 self.scheduler.free_buffers() 1604 1605 @staticmethod 1606 @functools.lru_cache(32) 1607 def candidate_tilings(node): 1608 ranges, reduction_ranges = node.get_ranges() 1609 if len(ranges) <= 1: 1610 return () 1611 1612 rw = node.pointwise_read_writes() 1613 assert len(rw.range_vars) == len(ranges) 1614 1615 # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads 1616 # that need to access the entire tensor; they don't contribute read indexing 1617 # information (and practically, they don't have dep.index so they can't be used 1618 # for stride_hints below 1619 dep_sources = [rw.reads, rw.writes] 1620 assert all( 1621 isinstance(dep, (MemoryDep, StarDep)) 1622 for dep in itertools.chain.from_iterable(dep_sources) 1623 ) 1624 deps = [ 1625 dep 1626 for dep in itertools.chain.from_iterable(dep_sources) 1627 if dep.name not in V.graph.removed_buffers and isinstance(dep, MemoryDep) 1628 ] 1629 write_names = {dep.name for dep in rw.writes} 1630 1631 tilings: List[CandidateTiling] = [] 1632 1633 for dep in deps: 1634 strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) 1635 assert len(strides) == len(ranges) 1636 try: 1637 split = strides.index(1) + 1 1638 if split == len(ranges): 1639 continue 1640 if all(s == 0 for s in strides[split:]): 1641 # if this is a broadcasted tensor and all dimensions after split are broadcast, 1642 # this is not a real split 1643 continue 1644 1645 except ValueError: 1646 continue 1647 tiled_groups = ( 1648 V.graph.sizevars.simplify(sympy_product(ranges[:split])), 1649 V.graph.sizevars.simplify(sympy_product(ranges[split:])), 1650 ) 1651 # score by number of elements 1652 score = V.graph.sizevars.size_hint( 1653 sympy_product( 1654 size for size, stride in zip(ranges, strides) if stride != 0 1655 ) 1656 ) 1657 if dep.name in write_names: 1658 # ngimel said contiguous writes is more important than reads 1659 score *= 2 1660 if CandidateTiling.is_good_size(tiled_groups[0]): 1661 score *= 2 1662 if CandidateTiling.is_good_size(tiled_groups[1]): 1663 score *= 2 1664 1665 if ( 1666 V.graph.sizevars.size_hint( 1667 score - sympy_product(itertools.chain(ranges, reduction_ranges)) 1668 ) 1669 >= 0 1670 ): 1671 tilings.append(CandidateTiling(tiled_groups, score, dep.name)) 1672 return tilings 1673 1674 @classmethod 1675 def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)): 1676 """ 1677 Heuristics to decide how to tile kernels. 1678 Currently, we tile based on stride-1 dimensions. 1679 1680 Returns: 1681 `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` 1682 1683 """ 1684 if reduction_numel != 1 or config.triton.max_tiles <= 1: 1685 # TODO(jansel): should we tile reductions? 1686 # do perf hint here if stride-1 dim is not being reduced 1687 if perf_hint_log.level <= logging.WARNING: 1688 for node in EnableReduction.filter(node_schedule): 1689 if len(cls.candidate_tilings(node)) > 0: 1690 perf_hint_log.info("reduction over non-contiguous dims") 1691 break 1692 return (numel, reduction_numel) 1693 1694 seen_names: OrderedSet[str] = OrderedSet() 1695 candidate_tiles: Counter[Any] = collections.Counter() 1696 for node in EnableReduction.filter(node_schedule): 1697 for tiling in cls.candidate_tilings(node): 1698 if tiling.name in seen_names: 1699 continue 1700 seen_names.add(tiling.name) 1701 candidate_tiles[tiling.tiling] += tiling.score 1702 1703 ranked_tilings = [tiling for tiling, score in candidate_tiles.most_common()] 1704 1705 if config.triton.max_tiles >= 3: 1706 # Consider adding a third dimension of tiling, but only 1707 # when a1 is a multiple of b1; otherwise, you have a lot 1708 # of stragglers which is annoying to generate code for. 1709 # 1710 # NB: More than three max tiles is not enabled by default. 1711 1712 # Add one 3D tiling choice 1713 for i in range(1, len(ranked_tilings)): 1714 a0, a1 = ranked_tilings[0] 1715 b0, b1 = ranked_tilings[i] 1716 if V.graph.sizevars.size_hint(a1 - b1) == 0: 1717 continue 1718 if V.graph.sizevars.size_hint(a1 - b1) < 0: 1719 # swap so a0 is bigger 1720 a0, a1 = ranked_tilings[i] 1721 b0, b1 = ranked_tilings[0] 1722 assert V.graph.sizevars.size_hint(a1 - b1) > 0 1723 if V.graph.sizevars.statically_known_multiple_of(a1, b1): 1724 tiling = (a0, FloorDiv(a1, b1), b1) 1725 ranked_tilings = [tiling] + ranked_tilings 1726 break # only 1 choice for now 1727 1728 if len(ranked_tilings) > 1: 1729 perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) 1730 1731 # Optionally, prefer tiling into as many dimensions as possible. 1732 if config.triton.prefer_nd_tiling: 1733 # Get candidate tilings from the node ranges. 1734 node_ranges = [ 1735 node.get_ranges()[0] 1736 for node in EnableReduction.filter(node_schedule) 1737 if isinstance(node, scheduler.SchedulerNode) 1738 ] 1739 new_tilings: OrderedSet[Tuple[sympy.Expr]] = OrderedSet() 1740 for node_range in node_ranges: 1741 # Collapse leading dims, to fit in the maximum dimensionality. 1742 num_leading_dims = max(0, len(node_range) - config.triton.max_tiles) 1743 first_trailing_dim = num_leading_dims + 1 1744 collapsed_leading_dim = sympy_product(node_range[:first_trailing_dim]) 1745 tiling = [collapsed_leading_dim] + list(node_range[first_trailing_dim:]) 1746 new_tilings.add(tuple(tiling)) 1747 1748 # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D. 1749 # Since this is a stable sort, ties are broken by schedule order. 1750 ranked_new_tilings = sorted(new_tilings, key=len, reverse=True) 1751 ranked_tilings = ranked_new_tilings + ranked_tilings 1752 1753 for tiled_groups in ranked_tilings: 1754 new_groups = (*tiled_groups, reduction_numel) 1755 if all( 1756 SIMDKernel.is_compatible(new_groups, node.get_ranges()) 1757 for node in node_schedule 1758 if isinstance(node, scheduler.SchedulerNode) 1759 ): 1760 return new_groups 1761 1762 return (numel, reduction_numel) 1763 1764 def flush(self): 1765 pass 1766 1767 def ready_to_flush(self) -> bool: 1768 return False 1769 1770 def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): 1771 @dataclasses.dataclass 1772 class LastUsageHolder: 1773 n: Any 1774 last_usage: Any 1775 1776 def __del__(self) -> None: 1777 self.n.last_usage = self.last_usage 1778 1779 last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] 1780 1781 # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. 1782 for n in nodes: 1783 n.last_usage = OrderedSet() 1784 1785 if not nodes[0].is_template(): 1786 _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group 1787 node_schedule = self.generate_node_schedule(nodes, numel, rnumel) 1788 1789 tiled_groups = self.select_tiling(node_schedule, numel, rnumel) 1790 reduction_hint_val, mutations, index_dtype = self.get_kernel_args( 1791 node_schedule, numel, rnumel 1792 ) 1793 1794 kernel = self.kernel_type( 1795 *tiled_groups, 1796 reduction_hint=reduction_hint_val, 1797 mutations=mutations, 1798 index_dtype=index_dtype, 1799 ) 1800 1801 self.codegen_node_schedule_with_kernel(node_schedule, kernel) 1802 with config.patch( 1803 "benchmark_kernel", benchmark_kernel 1804 ), V.set_kernel_handler(kernel): 1805 src_code = kernel.codegen_kernel() 1806 else: 1807 template_node = nodes[0] 1808 epilogue_nodes = nodes[1:] 1809 1810 with config.patch("benchmark_kernel", benchmark_kernel): 1811 src_code = self.codegen_template( 1812 template_node, epilogue_nodes, only_gen_src_code=True 1813 ) 1814 1815 src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") 1816 return src_code 1817 1818 def codegen_comment(self, node_schedule): 1819 pass 1820 1821 def define_kernel(self, src_code, node_schedule, kernel): 1822 raise NotImplementedError 1823 1824 1825@dataclasses.dataclass 1826class CandidateTiling: 1827 tiling: Tuple[sympy.Expr, sympy.Expr] 1828 score: int # higher is better 1829 name: Optional[str] = None 1830 1831 @staticmethod 1832 def is_good_size(s): 1833 """Somewhat arbitrary heuristic used to boost scores for some sizes""" 1834 s = V.graph.sizevars.size_hint(s) 1835 return s >= 32 and (s % 32 == 0) 1836 1837 1838class DisableReduction: 1839 """ 1840 Marker to invoke `kernel.disable_reduction()`. This closes a 1841 reduction loop and allows for pointwise ops to occur on the output 1842 of a reduction. 1843 """ 1844 1845 1846class EnableReduction: 1847 """ 1848 Marker to end a DisableReduction block. 1849 """ 1850 1851 @staticmethod 1852 def filter(node_schedule): 1853 """ 1854 Get the nodes from node_schedule skipping those in a 1855 DisableReduction block. 1856 """ 1857 disabled = False 1858 for node in node_schedule: 1859 if node in (EnableReduction, DisableReduction): 1860 # Don't tile stuff outside the main reduction loop 1861 disabled = node is DisableReduction 1862 elif disabled: 1863 pass 1864 else: 1865 yield node 1866 1867 1868class CantSplit(Exception): 1869 pass 1870