# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import functools
import itertools
import logging
import os
import textwrap
from functools import lru_cache
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy

import torch
import torch._logging
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package

from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir
from ..codecache import code_hash, get_path, PyCodeCache
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..utils import (
    cache_on_self,
    get_bounds_index_expr,
    get_fused_kernel_name,
    get_kernel_metadata,
    is_welford_reduction,
    Placeholder,
    sympy_dot,
    sympy_subs,
)
from ..virtualized import _ops as ops, OpsHandler, ReductionType, StoreMode, V
from ..wrapper_benchmark import get_kernel_category_by_source_code
from .common import (
    BackendFeature,
    CSE,
    CSEVariable,
    DeferredLine,
    IndentedBuffer,
    OpOverrides,
    PythonPrinter,
    SizeArg,
    TensorArg,
    WorkspaceArg,
)
from .simd import (
    constant_repr,
    IterationRangesEntry,
    IterationRangesRoot,
    pexpr,
    SIMDKernel,
    SIMDScheduling,
)
from .triton_utils import (
    config_of,
    should_unwrap_unspec_arg,
    signature_of,
    signature_to_meta,
)


if TYPE_CHECKING:
    from ..ir import IRNode

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")


@lru_cache(None)
def gen_attr_descriptor_import():
    """
    import AttrsDescriptor if the triton version is new enough to have this
    class defined.
97 """ 98 if not has_triton_package(): 99 return "" 100 101 import triton.compiler.compiler 102 103 if hasattr(triton.compiler.compiler, "AttrsDescriptor"): 104 return "from triton.compiler.compiler import AttrsDescriptor" 105 else: 106 return "" 107 108 109@lru_cache(None) 110def gen_common_triton_imports(): 111 imports = IndentedBuffer() 112 imports.splice( 113 """ 114 import triton 115 import triton.language as tl 116 """ 117 ) 118 if attr_desc := gen_attr_descriptor_import(): 119 imports.writeline(attr_desc) 120 121 imports.splice( 122 """ 123 from torch._inductor.runtime import triton_helpers, triton_heuristics 124 from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math 125 from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties 126 """ 127 ) 128 return imports.getvalue() 129 130 131block_offsets = { 132 symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True) 133 for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX] 134} 135 136block_sizes = { 137 symt: sympy.Symbol(f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True) 138 for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.RINDEX] 139} 140 141 142@dataclasses.dataclass 143class IndexingOptions: 144 index_str: str 145 mask_vars: OrderedSet[str] 146 mask_str: str 147 expand_str: Optional[str] 148 _has_rindex: bool 149 index: sympy.Expr 150 151 def has_mask(self): 152 return bool(self.mask_vars) 153 154 def has_indirect(self): 155 return free_symbol_is_type(self.index, SymT.TMP) 156 157 def has_rindex(self): 158 return self._has_rindex 159 160 def has_tmpmask(self): 161 return "tmp" in self.mask_str 162 163 def has_rmask(self): 164 return "rmask" in self.mask_str 165 166 167@dataclasses.dataclass 168class BlockPtrOptions: 169 params: BlockParameters 170 constant_offset: sympy.Expr 171 order: List[int] 172 mask_vars: OrderedSet[str] 173 reshape_suffix: List[str] 174 175 @property 176 def shape(self) -> List[sympy.Expr]: 177 return self.params.shape 178 179 @property 180 def block_shape(self) -> List[sympy.Expr]: 181 return self.params.block_shape 182 183 @property 184 def strides(self) -> List[sympy.Expr]: 185 return self.params.strides 186 187 @property 188 def offsets(self) -> List[sympy.Expr]: 189 return self.params.offsets 190 191 @staticmethod 192 def create( 193 *, 194 params: BlockParameters, 195 constant_offset: sympy.Expr, 196 range_trees: List[IterationRangesEntry], 197 mask_vars: OrderedSet[str], 198 ) -> BlockPtrOptions: 199 """Helper to create a BlockPtrOptions instance""" 200 reshape_suffix = [f"{t.prefix.upper()}BLOCK" for t in range_trees] 201 202 # Only drop broadcast dims if the output has the same 203 # rank as the block. Otherwise, we will get shape errors. 
        drop_broadcasts = len(reshape_suffix) == len(params.strides)

        broadcasting_dim = [s == 0 for s in params.strides]
        for i, is_broadcasting in enumerate(broadcasting_dim):
            if is_broadcasting and drop_broadcasts:
                # drop any stride==0 dimensions for performance
                reshape_suffix[i] = "1"

        if V.kernel.no_x_dim:
            assert range_trees[0].prefix == "x"
            reshape_suffix.pop(0)

        if (
            not V.kernel.inside_reduction
            and len(params.strides) == len(V.kernel.numels) - 1
            and V.kernel.numels[-1] != 1
        ):
            # Need to expand rank by 1 to match rank when self.inside_reduction=True
            reshape_suffix.append("1")

        def filter(it):
            """Removes any broadcasting dims from a given sequence"""
            assert len(it) == len(broadcasting_dim)
            return [
                item
                for item, is_broadcasting in zip(it, broadcasting_dim)
                if not is_broadcasting or not drop_broadcasts
            ]

        # Drop broadcasting dimensions from the input.
        params = BlockParameters(
            **{key: filter(val) for key, val in dataclasses.asdict(params).items()}
        )

        def lookup_size(exprs: Iterable[sympy.Expr]) -> List[sympy.Expr]:
            return [V.graph.sizevars.lookup_precomputed_size(expr) for expr in exprs]

        # Look up precomputed sizes
        params.shape = lookup_size(params.shape)
        params.strides = lookup_size(params.strides)

        return BlockPtrOptions(
            params=params,
            constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset),
            order=list(reversed(range(len(params.shape)))),
            mask_vars=mask_vars,
            reshape_suffix=reshape_suffix,
        )

    def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
        """
        Replaces instances of roffset with the new expression.
        """
        roffset = block_offsets[SymT.RINDEX]
        return sympy_subs(expr, {roffset: replacement})

    def format(self, name: str, roffset=True) -> str:
        """
        Codegen a call to tl.make_block_ptr()

        Args:
            name: variable name for pointer
            roffset: should roffset be included in offsets=..., for use with tl.advance()

        Returns:
            "tl.make_block_ptr(...)"
        """
        f = V.kernel.index_to_str
        offsets = [*self.offsets]
        if not roffset:
            offsets = [
                self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets
            ]
        args = [
            f"{name} + ({f(self.constant_offset)})"
            if self.constant_offset != 0
            else name,
            f"shape={f(self.shape)}",
            f"strides={f(self.strides)}",
            f"block_shape={f(self.block_shape)}",
            f"order={f(self.order)}",
            f"offsets={f(offsets)}",
        ]
        return f"tl.make_block_ptr({', '.join(args)})"

    @cache_on_self
    def boundary_check(self) -> List[int]:
        """List of indices to pass to tl.load(boundary_check=...)"""
        sizevars = V.graph.sizevars

        # Substitute maximum block sizes in shape expressions.
        # This works in multiple_of checks because block sizes are powers of 2.
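        # Illustrative example (hypothetical sizes): if shape[i] is a multiple of
        # TRITON_MAX_BLOCK["X"], it is also a multiple of every power-of-2 XBLOCK
        # not exceeding that maximum, so the boundary check can be skipped for
        # that dimension.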
        block_to_max: Dict[sympy.Expr, Any] = {
            block_size: TRITON_MAX_BLOCK[prefix_str[symt].upper()]
            for symt, block_size in block_sizes.items()
        }

        return [
            idx
            for idx in range(len(self.shape))
            if (
                not sizevars.statically_known_equals(
                    self.strides[idx], sympy.Integer(0)
                )
                and not sizevars.statically_known_multiple_of(
                    self.shape[idx], self.block_shape[idx]
                )
                and not sizevars.statically_known_multiple_of(
                    self.shape[idx], sympy_subs(self.block_shape[idx], block_to_max)
                )
                and not (
                    V.kernel.no_x_dim
                    and self.block_shape[idx] == block_sizes[SymT.XBLOCK]
                )
            )
        ]

    def advance_roffset(self):
        """
        Codegen string to pass to tl.advance(name, ...).

        Advance is the difference between offsets in each loop iteration.
        To compute it, we replace roffset with multiples of RBLOCK.
        Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first
        iteration has roffset=0, while the second has roffset=RBLOCK.
        """
        rblock = block_sizes[SymT.RINDEX]
        advance = [
            (
                self.replace_roffset(offset, rblock)
                - self.replace_roffset(offset, sympy.Integer(0))
            )
            for offset in self.offsets
        ]
        return V.kernel.index_to_str(advance)

    def has_indirect(self):
        return False  # block_ptr can't do indirect indexing

    def has_rindex(self) -> bool:
        return any(free_symbol_is_type(expr, SymT.RINDEX) for expr in self.block_shape)

    def has_rmask(self):
        return self.has_rindex()

    def has_tmpmask(self):
        return False  # block_ptr can't do indirect indexing

    def has_mask(self):
        return bool(self.boundary_check())


def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]):
    """Workaround https://github.com/openai/triton/issues/2836"""
    assert isinstance(old_shape, list) and isinstance(new_shape, list)
    if old_shape == new_shape:
        return value
    if [s for s in new_shape if s != "1"] != old_shape:
        return f"tl.reshape({value}, [{', '.join(new_shape)}])"
    # rewrite to [:, None] syntax, which is less buggy
    idx = 0
    expand = []
    for size in new_shape:
        if idx < len(old_shape) and size == old_shape[idx]:
            expand.append(":")
            idx += 1
        else:
            assert size == "1"
            expand.append("None")
    assert idx == len(old_shape)
    return f"{value}[{', '.join(expand)}]"


# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics).
We 380# must override all of these, or it is potential silent correctness problem 381class TritonPrinter(PythonPrinter): 382 def _print_TruncToInt(self, expr): 383 assert len(expr.args) == 1 384 return ( 385 f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" 386 ) 387 388 def _print_ToFloat(self, expr): 389 assert len(expr.args) == 1 390 return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" 391 392 def _print_PythonMod(self, expr): 393 quot, div = expr.args 394 quot_s = self._print(quot) 395 div_s = self._print(div) 396 if quot.is_nonnegative and div.is_nonnegative: 397 return f"{self.paren(quot_s)} % {self.paren(div_s)}" 398 return f"triton_helpers.remainder_integer({quot_s}, {div_s})" 399 400 def _print_FloorDiv(self, expr): 401 assert expr.is_integer 402 quot, div = expr.args 403 quot_s = self._print(quot) 404 div_s = self._print(div) 405 if quot.is_nonnegative and div.is_nonnegative: 406 return f"({self.paren(quot_s)} // {self.paren(div_s)})" 407 return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" 408 409 # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher 410 # precision algorithm, which we would need to replicate here 411 def _print_IntTrueDiv(self, expr): 412 lhs, rhs = expr.args 413 return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" 414 415 # NB: sympy.floor/ceiling produce integers, so we have to do the 416 # conversion to index dtype 417 def _print_floor(self, expr): 418 assert len(expr.args) == 1 419 return ( 420 f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" 421 ) 422 423 def _print_FloorToInt(self, expr): 424 assert len(expr.args) == 1 425 return ( 426 f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" 427 ) 428 429 def _print_ceiling(self, expr): 430 assert len(expr.args) == 1 431 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" 432 433 def _print_CeilToInt(self, expr): 434 assert len(expr.args) == 1 435 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" 436 437 def _helper_sqrt(self, expr): 438 return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" 439 440 def _print_FloatPow(self, expr): 441 return ( 442 f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" 443 ) 444 445 _print_PowByNatural = _print_FloatPow 446 447 def _print_Where(self, expr): 448 c = self.doprint(expr.args[0]) 449 p = self.doprint(expr.args[1]) 450 q = self.doprint(expr.args[2]) 451 return f"tl.where({c}, {p}, {q})" 452 453 def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: 454 """ 455 Helper for max/min code genereration. 456 cmp: > or < 457 """ 458 nargs = len(expr.args) 459 if len(expr.args) == 1: 460 return self._print(expr.args[0]) 461 462 mid = len(expr.args) // 2 463 cls = type(expr) 464 a = self._print(cls(*expr.args[:mid])) 465 b = self._print(cls(*expr.args[mid:])) 466 467 # Use a macro so we can propagate constexprs. 
        # https://github.com/triton-lang/triton/issues/3815
        a, b = tuple(f"({x})" for x in (a, b))
        assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'"
        return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))"

    def _print_Min(self, expr):
        return self._print_min_max_helper(expr, "<")

    def _print_Max(self, expr):
        return self._print_min_max_helper(expr, ">")

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"tl_math.abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        return f"libdevice.llrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits < 0 should have been filtered by the sympy function
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"


texpr = TritonPrinter().doprint


def triton_compute_type(dtype):
    triton_type_name = str(dtype).split(".")[-1]
    if triton_type_name == "bool":
        triton_type_name = "int1"
    elif (
        triton_type_name in ("float16", "bfloat16")
        and config.triton.codegen_upcast_to_fp32
    ):
        # float16 math is done in float32 inside the kernel
        triton_type_name = "float32"
    elif triton_type_name == "float8_e4m3fn":
        triton_type_name = "float8e4nv"
    elif triton_type_name == "float8_e5m2":
        triton_type_name = "float8e5"
    elif triton_type_name == "float8_e4m3fnuz":
        triton_type_name = "float8e4b8"
    elif triton_type_name == "float8_e5m2fnuz":
        triton_type_name = "float8e5b16"
    return f"tl.{triton_type_name}"


def _get_primitive_bitwidth(dtype):
    if hasattr(dtype, "is_floating_point"):
        if dtype.is_floating_point:
            # triton_compute_type changes the bitwidth
            if (
                dtype in [torch.bfloat16, torch.float16]
                and config.triton.codegen_upcast_to_fp32
            ):
                return 32
            return torch.finfo(dtype).bits
        else:
            return torch.iinfo(dtype).bits
    else:
        return -1


def triton_store_type(dtype):
    triton_type_name = str(dtype).split(".")[-1]
    if triton_type_name == "bool":
        triton_type_name = "int8"
    elif triton_type_name == "float8_e4m3fn":
        triton_type_name = "float8e4nv"
    elif triton_type_name == "float8_e5m2":
        triton_type_name = "float8e5"
    return f"tl.{triton_type_name}"


def triton_acc_type(dtype):
    if is_integer_dtype(dtype) and dtype.is_signed:
        nbits = 64 if dtype == torch.int64 else 32
        return f"tl.int{nbits}"
    return triton_compute_type(dtype)


class TritonCSEVariable(CSEVariable):
    def __init__(self, name, bounds: ValueRanges[Any]) -> None:
        super().__init__(name, bounds)
        # We'll use this to track which masks the variable needs when used for indirect indexing
        self.mask_vars: OrderedSet[str] = OrderedSet()

    def update_on_args(self, name, args, kwargs):
        for arg in args:
            if isinstance(arg, TritonCSEVariable):
                self.mask_vars.update(arg.mask_vars)
            elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr":
                # most of the time index vars don't need masks associated with them
                # however, when index vars are used to compute indices for indirect reads,
                # those reads should subsequently be masked
                self.mask_vars.update({f"{arg.name[0]}mask"})


class TritonOverrides(OpOverrides):
    """Map element-wise ops to Triton"""

    @staticmethod
    def to_dtype(
        x,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ):
        def _get_min_elements_per_thread(
            src_dtype: torch.dtype, dst_dtype: torch.dtype
        ) -> int:
            if src_dtype == dst_dtype:
                # No data type conversion is needed. No requirements on min_elem_per_thread.
                return 0

            # fp8 data type conversions have min_elem_per_thread requirements.
            # Refer to Triton implementations here:
            # https://github.com/openai/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10.
            fp8_dtypes = (
                torch.float8_e4m3fn,
                torch.float8_e5m2,
            )
            # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2.
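            # Illustrative (hypothetical) example: to_dtype(x, torch.float8_e5m2,
            # src_dtype=torch.float8_e4m3fn) would trip the assert below, since
            # direct e4m3 <-> e5m2 casts are rejected here.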
            assert not (
                src_dtype in fp8_dtypes
                and dst_dtype in fp8_dtypes
                and src_dtype != dst_dtype
            ), "Conversions between float8_e5m2 and float8_e4m3fn are not supported!"
            if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2:
                return 4
            if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn:
                return 2
            # No requirements on min_elem_per_thread.
            return 0

        if src_dtype is not None:
            # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype).
            # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions
            # in the same kernel.
            V.kernel.min_elem_per_thread = max(
                _get_min_elements_per_thread(src_dtype, dtype),
                V.kernel.min_elem_per_thread,
            )

        if dtype == torch.bool:
            return f"({x} != 0)"
        elif dtype == torch.uint8:
            # to work around llvm uint conversion semantics
            # that produce 0's for negative values
            return f"{x}.to(tl.int8).to(tl.uint8)"

        if use_compute_types:
            out_dtype = triton_compute_type(dtype)
        else:
            out_dtype = triton_store_type(dtype)

        return f"{x}.to({out_dtype})"

    @staticmethod
    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
        triton_dtype = triton_compute_type(dtype)
        # We may promote float16 or bfloat16 to float32 and cause the
        # bitwidth of dtype to be different from the input tensor (i.e. float32).
        # In such a case, we will have to convert the input tensor to
        # its src_type, perform the bitcast, and then convert the bit-casted
        # tensor back to float to ensure we use values with the right precision.
        if (
            src_dtype in (torch.float16, torch.bfloat16)
            and config.triton.codegen_upcast_to_fp32
        ):
            triton_src_dtype = str(src_dtype).split(".")[-1]
            cast_x = f"{x}.to(tl.{triton_src_dtype})"
            if dtype in (torch.float16, torch.bfloat16):
                triton_type_name = str(dtype).split(".")[-1]
                triton_dtype = f"tl.{triton_type_name}"
            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
            return f"{cast_x}.to(tl.float32)"
        else:
            src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype)
            target_dtype_bitwidth = _get_primitive_bitwidth(dtype)
            bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False"
            return f"{x}.to({triton_dtype}, bitcast={bitcast})"

    @staticmethod
    def _shaped_constant(value, dtype, shape):
        type_ = torch._prims_common.dtype_to_type(dtype)
        triton_val = constant_repr(type_(value))
        triton_type = triton_compute_type(dtype)

        if triton_type == "tl.float32":
            # Float constants are always f32 in triton
            return triton_val

        # NOTE: We use a tensor here in order to get the expected type.
        # Otherwise, e.g. float64 constants would be truncated to float32.
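        # Illustrative example (hypothetical call): _shaped_constant(2, torch.int64, shape=[])
        # emits "tl.full([], 2, tl.int64)", keeping the constant in int64 rather than
        # letting Triton infer a narrower type.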
707 return f"tl.full({shape}, {triton_val}, {triton_type})" 708 709 @classmethod 710 def constant(cls, value, dtype): 711 return cls._shaped_constant(value, dtype, shape=[]) 712 713 @staticmethod 714 def abs(x): 715 return f"tl_math.abs({x})" 716 717 @staticmethod 718 def libdevice_abs(x): 719 return f"libdevice.abs({x})" 720 721 @staticmethod 722 def exp(x): 723 return f"tl_math.exp({x})" 724 725 @staticmethod 726 def libdevice_exp(x): 727 return f"libdevice.exp({x})" 728 729 @staticmethod 730 def exp2(x): 731 return f"libdevice.exp2({x})" 732 733 @staticmethod 734 def expm1(x): 735 return f"libdevice.expm1({x})" 736 737 @staticmethod 738 def sqrt(x): 739 return f"libdevice.sqrt({x})" 740 741 @staticmethod 742 def libdevice_sqrt(x): 743 return f"libdevice.sqrt({x})" 744 745 @staticmethod 746 def relu(x): 747 bug = config.triton.inject_relu_bug_TESTING_ONLY 748 if bug == "compile_error": 749 return "compile error!" 750 elif bug == "runtime_error": 751 # NB: this only triggers runtime error as long as input 752 # is not all zero 753 return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})' 754 elif bug == "accuracy": 755 return f"{x} + 1" 756 elif bug is None: 757 return ops.maximum(ops.constant(0, torch.int32), x) 758 else: 759 raise AssertionError( 760 f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}" 761 ) 762 763 @staticmethod 764 def minimum(a, b): 765 return f"triton_helpers.minimum({a}, {b})" 766 767 @staticmethod 768 def maximum(a, b): 769 return f"triton_helpers.maximum({a}, {b})" 770 771 @staticmethod 772 def where(a, b, c): 773 return f"tl.where({a}, {b}, {c})" 774 775 @staticmethod 776 def inline_asm_elementwise( 777 *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 778 ): 779 triton_type = triton_compute_type(dtype) 780 input_refs = ", ".join([str(i) for i in inputs]) 781 if constraints is None: 782 constraints = ", ".join(["=r"] + ["r" for _ in inputs]) 783 return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950 784 785 @staticmethod 786 def cos(x): 787 return f"tl_math.cos({x})" 788 789 @staticmethod 790 def libdevice_cos(x): 791 return f"libdevice.cos({x})" 792 793 @staticmethod 794 def sin(x): 795 return f"tl_math.sin({x})" 796 797 @staticmethod 798 def libdevice_sin(x): 799 return f"libdevice.sin({x})" 800 801 @classmethod 802 def index_expr(cls, expr, dtype): 803 raise NotImplementedError("ops.index_expr not implemented outside a kernel") 804 805 @staticmethod 806 def masked(mask, body, other): 807 raise NotImplementedError("ops.masked not implemented outside a kernel") 808 809 @staticmethod 810 def lgamma(x): 811 return f"libdevice.lgamma({x})" 812 813 @staticmethod 814 def erf(x): 815 return f"libdevice.erf({x})" 816 817 @staticmethod 818 def cosh(x): 819 return f"libdevice.cosh({x})" 820 821 @staticmethod 822 def sinh(x): 823 return f"libdevice.sinh({x})" 824 825 @staticmethod 826 def acos(x): 827 return f"libdevice.acos({x})" 828 829 @staticmethod 830 def acosh(x): 831 return f"libdevice.acosh({x})" 832 833 @staticmethod 834 def asin(x): 835 return f"libdevice.asin({x})" 836 837 @staticmethod 838 def asinh(x): 839 return f"libdevice.asinh({x})" 840 841 @staticmethod 842 def atan2(x, y): 843 return f"libdevice.atan2({x}, {y})" 844 845 @staticmethod 846 def atan(x): 847 return f"libdevice.atan({x})" 848 849 @staticmethod 850 def atanh(x): 851 return f"libdevice.atanh({x})" 852 853 @staticmethod 854 def copysign(x, 
        return f"libdevice.copysign({x}, {y})"

    @staticmethod
    def erfc(x):
        return f"libdevice.erfc({x})"

    @staticmethod
    def erfinv(x):
        return f"libdevice.erfinv({x})"

    @staticmethod
    def hypot(x, y):
        return f"libdevice.hypot({x}, {y})"

    @staticmethod
    def log10(x):
        return f"libdevice.log10({x})"

    @staticmethod
    def log2(x):
        return f"libdevice.log2({x})"

    @staticmethod
    def nextafter(x, y):
        return f"libdevice.nextafter({x}, {y})"

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_not(a):
        return f"{a} == 0"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"({a} ^ {b})"

    @staticmethod
    def bitwise_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def bitwise_not(a):
        return f"~{a}"

    @staticmethod
    def bitwise_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def bitwise_xor(a, b):
        return f"{a} ^ {b}"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"{a} << {b}"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"{a} >> {b}"

    @staticmethod
    def rand(seed, offset):
        offset = f"({offset}).to(tl.uint32)"
        return f"tl.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset):
        offset = f"({offset}).to(tl.uint32)"
        return f"tl.randn({seed}, {offset})"

    @staticmethod
    def randint64(seed, offset, low, high):
        offset = f"({offset}).to(tl.uint32)"
        return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})"

    @staticmethod
    def load_seed(name, offset):
        raise NotImplementedError("ops.load_seed not implemented outside a kernel")

    @staticmethod
    def rsqrt(x):
        return f"libdevice.rsqrt({x})"

    @staticmethod
    def log1p(x):
        return f"libdevice.log1p({x})"

    @staticmethod
    def tan(x):
        return f"libdevice.tan({x})"

    @staticmethod
    def tanh(x):
        return f"libdevice.tanh({x})"

    @staticmethod
    def sigmoid(x):
        return f"tl.sigmoid({x})"

    @staticmethod
    def signbit(x):
        # XX: This is wrong for the value -0.0 in floating point
        return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"

    @staticmethod
    def fmod(a, b):
        return f"libdevice.fmod({a}, {b})"

    @staticmethod
    def pow(a, b):
        return f"libdevice.pow({a}, {b})"

    @staticmethod
    def log(x):
        return f"tl_math.log({x})"

    @staticmethod
    def libdevice_log(x):
        return f"libdevice.log({x})"

    @staticmethod
    def isinf(x):
        return f"libdevice.isinf({x}).to(tl.int1)"

    @staticmethod
    def isnan(x):
        return f"libdevice.isnan({x}).to(tl.int1)"

    @staticmethod
    def round(x):
        return f"libdevice.nearbyint({x})"

    @staticmethod
    def floor(x):
        return f"libdevice.floor({x})"

    @staticmethod
    def floordiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Similar to div_floor_kernel_cuda in pytorch core.
        # Notice that // in triton behaves as truncdiv instead of floordiv
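        # Illustrative example (hypothetical values): for a = -7, b = 2, Triton's
        # "a // b" truncates to -3, but floor division should give -4; the
        # tl.where below subtracts 1 exactly when the signs differ and the
        # remainder is nonzero.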
        quot = f"{a} // {b}"
        rem = f"{a} % {b}"
        return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})"

    @staticmethod
    def sign(x):
        z = ops.constant(0, torch.int32)
        left = ops.to_dtype((ops.lt(z, x)), torch.int8)
        right = ops.to_dtype((ops.lt(x, z)), torch.int8)
        sub = ops.sub(left, right)
        return f"{sub}.to({x}.dtype)"

    @staticmethod
    def trunc(x):
        return f"libdevice.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # See the comment in lowering.div_mode. a and b are integer type.
        # Notice that // in triton behaves as truncdiv instead of floordiv
        return f"{a} // {b}"

    @staticmethod
    def ceil(x):
        return f"libdevice.ceil({x})"


TritonOverrides._initialize_pointwise_overrides("triton")


# Use mypy to check protocol implemented correctly
def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
    return h


class TritonKernelOverrides(TritonOverrides):
    """Map element-wise ops to Triton within a TritonKernel

    Unlike TritonOverrides, these assume the code is going to be inserted into
    the body of the main triton kernel and so it may use indexing and mask
    variables which are assumed to already be defined in the current scope.
    """

    @classmethod
    def constant(cls, value, dtype):
        # NOTE: Cannot use shape=[] as it's not supported by triton-rocm
        # We could use shape=[1] instead but starting with the correct
        # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR.
        ndim = V.kernel.triton_tensor_ndim()
        shape = [1] * ndim
        return cls._shaped_constant(value, dtype, shape=shape)

    @classmethod
    def index_expr(cls, expr, dtype):
        indexing = V.kernel.indexing(expr, block_ptr=False)
        assert isinstance(indexing, IndexingOptions)
        var = V.kernel.cse.generate(
            V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr)
        )

        if dtype not in (torch.int32, torch.int64):
            var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype))
        var.mask_vars = indexing.mask_vars
        return var

    @staticmethod
    def masked(mask, body, other):
        if mask is not None and torch.version.hip is not None:
            mask = V.kernel.cse.generate(
                V.kernel.compute,
                f"{mask}.to(tl.int1)",
            )

        nodes = body.graph.find_nodes(op="output")
        assert nodes, "graph for body does not contain an output"

        need_where = False
        for node in nodes:
            for arg in node.args:
                if arg.target != "load" or should_unwrap_unspec_arg(arg.args[0]):
                    need_where = True

        value = None if need_where else other
        with V.kernel.mask_loads(mask, value=value) as new_mask:
            result = body()

        if need_where:
            # Remove once CSEVariables track the dtype
            if result.bounds.is_bool:
                other = bool(other)
            # Take dtype from result to prevent accidental promotion
            other = V.kernel.cse.generate(
                V.kernel.compute,
                f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)",
                bounds=ValueRanges.wrap(other),
            )
            ret = ops.where(new_mask, result, other)
        else:
            ret = result

        ret.mask_vars.discard(new_mask)
        return ret

    @staticmethod
    def load_seed(name, offset):
        var = V.kernel.args.input(name)
        return (
            f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})"
        )

    @staticmethod
    def frexp(x):
        cache_key = f"frexp({x})"
        if cache_key in V.kernel.cse.cache:
            return V.kernel.cse.cache[cache_key]

        mantissa = V.kernel.cse.newvar()
        exponent = V.kernel.cse.newvar()
        V.kernel.compute.writeline(
            f"{mantissa}, {exponent} = triton_helpers.frexp({x})"
        )
        V.kernel.cse.cache[cache_key] = (mantissa, exponent)
        return (mantissa, exponent)


# Use mypy to check protocol implemented correctly
def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]:
    return h


class HelperFunctions:
    """An ordered set of helper functions."""

    _templates_seen: Dict[str, str]  # Template code to function name
    finalized_helpers: List[str]

    def __init__(self) -> None:
        self._templates_seen = {}
        self.finalized_helpers = []

    def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str:
        """This accepts a function definition with the function name
        left as a format specifier e.g.

            @triton.jit
            def {name}(arg0, arg1):
                return arg0 + arg1

        We add the templated code to the function set and return the name
        assigned to that function.

        """
        existing_name = self._templates_seen.get(template_code)
        if existing_name is not None:
            # Don't duplicate existing helpers
            return existing_name

        name = f"{base_name}{len(self.finalized_helpers)}"
        self._templates_seen[template_code] = name
        self.finalized_helpers.append(template_code.format(name=name))
        return name

    def __iter__(self):
        return iter(self.finalized_helpers)

    def __getitem__(self, idx):
        return self.finalized_helpers[idx]


@dataclasses.dataclass
class BlockParameters:
    """
    Class representing ND block dimensions, for block pointer analysis.
    """

    shape: List[sympy.Expr] = dataclasses.field(default_factory=list)
    block_shape: List[sympy.Expr] = dataclasses.field(default_factory=list)
    strides: List[sympy.Expr] = dataclasses.field(default_factory=list)
    offsets: List[sympy.Expr] = dataclasses.field(default_factory=list)

    def __add__(self, other: BlockParameters) -> BlockParameters:
        """
        Concatenates block parameters.
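
        Illustrative example (hypothetical symbols): BlockParameters(shape=[a])
        + BlockParameters(shape=[b]) gives BlockParameters(shape=[a, b]); the
        other fields (block_shape, strides, offsets) are concatenated the same way.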
1185 """ 1186 cls = type(self) 1187 a, b = tuple(dataclasses.asdict(x) for x in (self, other)) 1188 return cls(**{key: a[key] + b[key] for key in a}) 1189 1190 1191class TritonKernel(SIMDKernel): 1192 overrides = TritonKernelOverrides # type: ignore[assignment] 1193 helper_functions: HelperFunctions 1194 kexpr: Callable[[sympy.Expr], str] = texpr 1195 allow_block_ptr = True 1196 1197 def __init__( 1198 self, 1199 *groups, 1200 index_dtype: str, 1201 mutations: Optional[OrderedSet[str]] = None, 1202 pid_cache=None, 1203 reduction_hint=ReductionHint.DEFAULT, 1204 min_elem_per_thread=0, 1205 override_persistent_reduction=None, 1206 optimize_mask=True, 1207 ) -> None: 1208 self.optimize_mask: bool = optimize_mask 1209 super().__init__( 1210 *groups, 1211 index_dtype=index_dtype, 1212 mutations=mutations, 1213 reduction_hint=reduction_hint, 1214 pid_cache=pid_cache, 1215 override_persistent_reduction=override_persistent_reduction, 1216 ) 1217 self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment] 1218 self.outside_loop_vars: OrderedSet[Any] = OrderedSet() 1219 self.min_elem_per_thread = min_elem_per_thread 1220 self.block_ptr_id = itertools.count() 1221 self.helper_functions = HelperFunctions() 1222 1223 # A set of autotuning hints to pass as part of triton_meta 1224 self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet() 1225 self.triton_meta: Optional[Dict[str, object]] = None 1226 1227 self.codegen_range_tree() 1228 1229 def _get_symt(self, tree: IterationRangesEntry) -> SymT: 1230 prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()} 1231 return prefix_to_symt[tree.prefix] 1232 1233 def _get_block_size(self, tree: IterationRangesEntry) -> sympy.Symbol: 1234 return block_sizes[self._get_symt(tree)] 1235 1236 def _get_block_offset(self, tree: IterationRangesEntry) -> sympy.Symbol: 1237 return block_offsets[self._get_symt(tree)] 1238 1239 def _max_block_size(self, tree: IterationRangesEntry) -> int: 1240 return TRITON_MAX_BLOCK[tree.prefix.upper()] 1241 1242 def codegen_range_tree(self): 1243 for tree in self.range_trees: 1244 # reduction indexing goes inside a loop 1245 if not tree.is_loop: 1246 self.iteration_ranges_codegen_header(tree, self.body) 1247 if self.inside_reduction and self.range_trees[-1].is_loop: 1248 # workaround for this issue: 1249 # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 1250 self.body.writeline( 1251 f"rbase = {self.iteration_ranges_ranges_code(self.range_trees[-1])}" 1252 ) 1253 1254 def need_numel_args(self): 1255 r""" 1256 Indicate whether we need provide numel as arguments for the generated 1257 kernel calls in the benchmark. 1258 1259 Should be true for pointwise/reduction kernels but false for triton 1260 matmul kernels. 1261 """ 1262 return True 1263 1264 def should_use_persistent_reduction(self) -> bool: 1265 """ 1266 Heuristic to set self.persistent_reduction and add guards 1267 if needed. 1268 """ 1269 if not (self.inside_reduction and config.triton.persistent_reductions): 1270 return False 1271 threshold = { 1272 ReductionHint.INNER: 1024, 1273 }.get(self.reduction_hint, 64) 1274 1275 # If multi_kernel is enabled, we do more aggressive persistent reduction. 1276 # This may result in some persistent reductions slower than the 1277 # corresponding non-persistent reductions. MultiKernel will do benchmarking 1278 # to pick the faster one. 
        if config.triton.multi_kernel:
            threshold *= 16
        last_numel = self.numels[-1]
        return V.graph.sizevars.statically_known_leq(last_numel, threshold)  # type: ignore[arg-types]

    def want_no_x_dim(self):
        return (
            self.reduction_hint == ReductionHint.INNER
            and self.persistent_reduction
            and len(self.numels) == 2
            and V.graph.sizevars.statically_known_geq(self.numels[-1], 256)  # type: ignore[arg-types]
        )

    @property
    def assert_function(self) -> str:
        return "tl.device_assert"

    def indexing(
        self,
        index: sympy.Expr,
        *,
        copy_shape=None,
        dense_indexing=False,
        override_mask=None,
        block_ptr=False,
    ):
        """
        Compute the index and mask to pass to tl.load() or tl.store()
        """
        index = self.prepare_indexing(index)
        index_vars = index.free_symbols
        has_rindex = False

        mask_vars: OrderedSet[str] = OrderedSet()
        for var in index_vars:
            assert isinstance(var, sympy.Symbol)
            has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
            if override_mask:
                pass
            elif symbol_is_type(var, SymT.TMP):
                # indirect indexing
                cse_var = self.cse.varname_map[var.name]
                mask_vars.update(cse_var.mask_vars)
            elif symbol_is_type(
                var,
                (
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
                    SymT.INDEX,
                    SymT.FLOAT,
                    SymT.UNBACKED_FLOAT,
                ),
            ):
                pass
            else:
                # var is one of xN, yN or rN
                assert symbol_is_type(
                    var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK)
                ), var.name
                mask_vars.add(f"{var.name[0]}mask")

        need_dense = (
            config.triton.dense_indexing
            or dense_indexing
            or self._load_mask is not None
        ) and index != 0

        have_dense = True
        have_loop_vars = False
        dense_mask_vars: OrderedSet[str] = OrderedSet()

        for tree in self.active_range_trees():
            if index_vars.intersection(tree.var_list):
                have_loop_vars = True
            else:
                have_dense = False
            dense_mask_vars.add(f"{tree.prefix}mask")

        if (
            block_ptr
            and self.allow_block_ptr
            and config.triton.use_block_ptr
            and not override_mask
            and not self._load_mask
            and len(mask_vars - dense_mask_vars) == 0
            and not self.is_indirect_indexing(index)
            and have_loop_vars
            # workaround https://github.com/openai/triton/issues/2821
            and self.index_dtype == "tl.int32"
        ):

            def match_strided_block(
                index: sympy.Expr, range_tree: IterationRangesEntry
            ) -> Optional[BlockParameters]:
                """
                Matches expressions of the form:
                    idx = s * xindex

                This implies stride (s,), and shape (XBLOCK,).
                """
                symbol = range_tree.symbol()
                stride = sympy.Wild("stride", exclude=[symbol])
                m = index.match(symbol * stride)
                if m is None:
                    return None

                return BlockParameters(
                    shape=[range_tree.numel],
                    block_shape=[self._get_block_size(range_tree)],
                    strides=[m[stride]],
                    offsets=[self._get_block_offset(range_tree)],
                )

            def match_mod_div_block(
                index: sympy.Expr, range_tree: IterationRangesEntry
            ) -> Optional[BlockParameters]:
                """
                Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing.

                Example expression to match:
                    sN * ((rindex//(d1 * ... * d(N-1))))
                        + s1 * ModularIndexing(rindex, 1, d1)
                        + ...
                        + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1))

                This iterates over a block of shape (dN, ..., d1) and stride
                (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are
                wildcards that we match.

                Note that dN does not appear in the expression, but we solve for it
                using range tree numels and the other dims.
                """
                # Bound the possible number of dims. We use the following heuristics:
                # - At least one dim for each range tree node.
                # - At least one dim for every FloorDiv or ModularIndexing op.
                # - At least 2 dims to pattern match.
                num_dims = max(
                    2,
                    len(self.range_tree_nodes),
                    (index.count(FloorDiv) + index.count(ModularIndexing)),
                )

                # Pattern match to find the strides and offset.
                index_var = range_tree.symbol()
                wild = functools.partial(sympy.Wild, exclude=[index_var])
                dims: List[sympy.Expr] = [
                    wild(f"dim_mod{idx}") for idx in range(num_dims)
                ]
                strides: List[sympy.Expr] = [
                    wild(f"stride_mod{idx}") for idx in range(num_dims)
                ]

                def get_slice_numels(dims: List[Any]) -> List[Any]:
                    """
                    Compute the cumulative size of each dimension's slice.
                    This proceeds from the last dim up to the second.
                    """
                    numels = [sympy.Integer(1)]
                    for dim in dims[:0:-1]:
                        numel = dim * numels[0]
                        numels.insert(0, numel)
                    return numels

                # The first dimension's index is computed by division.
                # The remaining are computed by modulo.
                slice_numels = get_slice_numels(dims[:num_dims])
                block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
                    ModularIndexing(index_var, numel, dim)
                    for dim, numel in zip(dims[1:], slice_numels[1:])
                ]

                # Calculate a linear index from block indices.
                match_expr = sympy_dot(strides, block_index_exprs)

                # Pattern match.
                match = index.match(match_expr)
                if match is None:
                    return None

                # Provide default values for unmatched dims and strides.
                for dim in dims[1:]:
                    if dim not in match:
                        match[dim] = sympy.Integer(1)
                for stride in strides[1:]:
                    if stride not in match:
                        match[stride] = sympy.Integer(0)

                sizevars = V.graph.sizevars

                def get_match(expr: sympy.Expr) -> sympy.Expr:
                    return sizevars.lookup_precomputed_size(match[expr])

                # Replace wildcards with matched expressions.
                dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
                strides = [get_match(stride) for stride in strides]
                slice_numels = get_slice_numels(dims)
                block_index_exprs = [
                    sympy_subs(expr, match) for expr in block_index_exprs
                ]

                # The leading dimension is not directly matched in our expression.
                # We solve for it by dividing the range tree numel by the product of
                # all other dimensions. We quit if they are not known to be divisible.
                assert (
                    dims[0] not in match
                ), "Expected not to match the leading dimension!"
                if not sizevars.statically_known_multiple_of(
                    range_tree.numel, slice_numels[0]
                ):
                    return None
                dims[0] = range_tree.numel / slice_numels[0]

                # Check for applicable iteration range sizes.
                # When mapping a 1D block into an ND one, we need to know that
                # the number of elements is not changed. This means the slice numels of
                # the ND iteration range must evenly divide the length of the 1D block.
                # There are two cases where we can guarantee this:
                # 1. Numels are powers of 2.
                #    If numel == 2 ** n, and we know XBLOCK == 2 ** m,
                #    with n and m integers, then either numel is a multiple of XBLOCK, or numel
                #    is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.)
                # 2. Numels are multiples of the maximum possible block size.
                max_block = self._max_block_size(range_tree)
                if any(
                    not sizevars.statically_known_multiple_of(numel, max_block)
                    and not sizevars.statically_known_power_of_2(numel)
                    for numel in slice_numels
                ):
                    return None

                def identity(expr: sympy.Expr) -> sympy.Expr:
                    return expr

                # Compute the ND block shape from the linear block size.
                # Use CeilDiv to round leading dimensions up to 1.
                # Non-leading dimensions are clamped to the size of the iteration range,
                # while the leading dimension can exceed this to accommodate a larger
                # block size.
                linear_block_size = self._get_block_size(range_tree)
                block_shape: List[sympy.Expr] = [
                    CeilDiv(linear_block_size, slice_numels[0])
                ] + [
                    sympy.Min(CeilDiv(linear_block_size, numel), dim)
                    for numel, dim in zip(slice_numels[1:], dims[1:])
                ]

                # Compute block offsets from {xyzr}offset and the matched expressions.
                block_offsets: List[sympy.Expr] = [
                    sympy_subs(expr, {index_var: self._get_block_offset(range_tree)})
                    for expr in block_index_exprs
                ]

                return BlockParameters(
                    shape=dims,
                    block_shape=block_shape,
                    strides=strides,
                    offsets=block_offsets,
                )

            def match_block_pointer_subexpr(
                expr: sympy.Expr, range_tree: IterationRangesEntry
            ) -> Optional[BlockParameters]:
                """
                Match a block indexing subexpression involving a single range tree.
                """
                for match_func in (
                    match_strided_block,
                    match_mod_div_block,
                ):
                    match = match_func(expr, range_tree)
                    if match is not None:
                        return match

                return None

            def match_block_pointer() -> Optional[BlockPtrOptions]:
                index_relative_to_xyr_index = sympy_subs(
                    index, {v: t.expr for v, t in self.range_tree_nodes.items()}
                )
                range_trees = self.active_range_trees(reorder=True)

                # Match each range tree separately.
                range_symbols = {tree.symbol() for tree in range_trees}
                index_terms = sympy.Add.make_args(index_relative_to_xyr_index)
                block_params = BlockParameters()
                for tree in range_trees:
                    # Partition the index into subexpressions pertaining to each range tree.
                    # For example xindex * 5 + rindex * 3 is partitioned to
                    # (xindex * 5, rindex * 3).
                    symbol = tree.symbol()
                    subexpr = sympy.Integer(0) + sum(
                        expr for expr in index_terms if symbol in expr.free_symbols
                    )

                    # Reject mixed terms, e.g. xindex * rindex.
                    # NB: the zero expression is allowed, for broadcasting.
                    if len(range_symbols.intersection(subexpr.free_symbols)) > 1:
                        return None

                    # Match the subexpression for this range tree.
                    params = match_block_pointer_subexpr(subexpr, tree)
                    if params is None:
                        return None
                    block_params += params

                # Collect leftover terms as a constant offset.
                offset = sum(
                    expr
                    for expr in index_terms
                    if not range_symbols.intersection(expr.free_symbols)
                )

                # Form the block pointer.
                self.filter_masks(mask_vars)
                return BlockPtrOptions.create(
                    params=block_params,
                    constant_offset=offset,
                    range_trees=range_trees,
                    mask_vars=mask_vars,
                )

            # Return a block pointer, if indexing matches the pattern.
            options = match_block_pointer()
            if options is not None:
                return options

        expand_str = None
        index_str = self.index_to_str(index)
        if isinstance(index, sympy.Integer):
            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
            index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
            return IndexingOptions(
                index_str, OrderedSet(), "None", expand_str, has_rindex, index
            )

        if need_dense and not have_dense:
            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
            index_str = f"tl.broadcast_to({index_str}, {expand_str})"
            mask_vars = dense_mask_vars
        elif not have_loop_vars and copy_shape:
            index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)"
            mask_vars = dense_mask_vars

        if override_mask:
            mask_vars = OrderedSet([override_mask])

        if self._load_mask:
            mask_vars.add(self._load_mask)

        self.filter_masks(mask_vars)

        mask_str = " & ".join(sorted(map(str, mask_vars))) if mask_vars else "None"
        return IndexingOptions(index_str, mask_vars, mask_str, expand_str, has_rindex, index)  # type: ignore[arg-type]

    def codegen_block_ptr(
        self, name: str, var: str, indexing: BlockPtrOptions, other=""
    ) -> Tuple[str, Optional[DeferredLine], str]:
        advance_block_ptr = None
        check = indexing.boundary_check()
        if not check:
            # workaround https://github.com/openai/triton/issues/2813
            other = ""
        elif other:
            assert other == ", other=0.0"
            other = f", boundary_check={check!r}, padding_option='zero'"
        else:
            other = f", boundary_check={check!r}"
        if (
            self.inside_reduction
            and self.range_trees[-1].is_loop
            and indexing.has_rindex()
        ):
            block_ptr = f"block_ptr{next(self.block_ptr_id)}"
            self.body.writeline(
                DeferredLine(
                    name, f"{block_ptr} = {indexing.format(var, roffset=False)}"
                )
            )
            advance_block_ptr = DeferredLine(
                name,
                f"{block_ptr} = tl.advance({block_ptr}, {indexing.advance_roffset()})",
            )
        else:
            block_ptr = indexing.format(var)
        return block_ptr, advance_block_ptr, other

    def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
        # broadcasting is not implicit for block_ptrs
        value = (
            f"tl.broadcast_to({value}, {self.index_to_str(indexing.reshape_suffix)})"
        )
        # drop any extra size=1 dimensions
        block_shape = [V.kernel.index_to_str(expr) for expr in indexing.block_shape]
        value = triton_reshape(value, indexing.reshape_suffix, block_shape)
        # workaround https://github.com/openai/triton/issues/2814
        value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
        return f"tl.store({block_ptr}, {value}{other})"

    def check_bounds(
        self,
        expr: sympy.Expr,
        size: sympy.Expr,
        lower: bool,
        upper: bool,
    ):
        if not (lower or upper):
            return

        assert isinstance(expr, sympy.Expr)
        indexing = self.indexing(expr, block_ptr=False)
        assert isinstance(indexing, IndexingOptions)

        index_str = indexing.index_str
        mask_str = indexing.mask_str if indexing.has_mask() else None
        size_str = texpr(self.rename_indexing(size)) if upper else None

        # expr is already wrapped
        line = self.indirect_assert(
            index_str, "0" if lower else None, size_str, mask_str
        )

        indirect = self.is_indirect_indexing(expr) or any(
            isinstance(m, TritonCSEVariable) for m in indexing.mask_vars
        )
        buffer = self.get_load_buffer(indexing)
        self.cse.generate(buffer, line, assignment=False)

    def get_load_buffer(self, indexing):
        if indexing.has_indirect() or indexing.has_tmpmask():
            # Masked loads must come after the mask is computed
            return self.compute
        elif (
            self.inside_reduction
            and self.range_trees[-1].is_loop
            and not indexing.has_rindex()
        ):
            # can lift a common load outside of reduction loop
            # One exception is when this is an indirect_load.
            return self.body
        else:
            return self.loads

    def load(self, name: str, index: sympy.Expr):
        var = self.args.input(name)
        indirect_indexing = self.is_indirect_indexing(index)
        original_index = index
        indexing = self.indexing(index, block_ptr=True)
        has_rindex = indexing.has_rindex()
        has_tmpmask = indexing.has_tmpmask()

        # Keep the variable in cache if we're going to reuse it. Equiv., if any of the following hold
        # 1) We are doing broadcasting
        # 2) It is a non-coalesced load. The intuition is that if it's
        #    non-coalesced, we will likely load each element multiple times in
        #    practice.
        # 3) It will be used later and it won't be CSE'd. Equiv., if all the following hold
        #    3.1) We are in a reduction loop
        #    3.2) It's not its last use
        #    3.3) This load will not be lifted to the body
        #
        is_coalesced = any(
            i == 1 for i in self.get_strides_of_load(original_index).values()
        )
        if self.is_broadcasted(original_index):
            ep = ", eviction_policy='evict_last'"
        elif not is_coalesced:
            ep = ", eviction_policy='evict_last'"
        elif self.inside_reduction and self.range_trees[-1].is_loop:
            if name in self.args.inplace_buffers:
                names: OrderedSet[str] = OrderedSet(
                    self.args.inplace_buffers[name].other_names
                )
            else:
                names = OrderedSet([name])
            last_use = len(names & self.last_usage) > 0
            evict_last = not last_use and (has_rindex or indirect_indexing)
            if evict_last:
                ep = ", eviction_policy='evict_last'"
            else:
                ep = ", eviction_policy='evict_first'"
        else:
            ep = ""

        if (has_tmpmask or has_rindex) and indexing.has_mask():
            if self._load_other:
                other = f", other={constant_repr(self._load_other)}"
            else:
                other = ", other=0.0"
        else:
            other = ""

        advance_block_ptr = None
        append_broadcast = None
        if should_unwrap_unspec_arg(name):
            line = var
        else:
            if isinstance(indexing, BlockPtrOptions):
                block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
                    name, var, indexing, other
                )
                line = f"tl.load({block_ptr}{other}{ep})"
                # add needed size=1 dimensions
                block_shape = [str(dim) for dim in indexing.block_shape]
                line = triton_reshape(line, block_shape, indexing.reshape_suffix)
            elif isinstance(original_index, sympy.Integer):
                line = f"tl.load({var} + ({original_index}))"
                append_broadcast = indexing.expand_str
            else:
                line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"

        dtype = V.graph.get_dtype(name)
        if (
            dtype in (torch.float16, torch.bfloat16)
            and config.triton.codegen_upcast_to_fp32
        ):
            line += ".to(tl.float32)"
        if dtype == torch.bool and torch.version.hip is None:
            # Workaround for https://github.com/openai/triton/issues/2151
            # tl.load returns int8 when loading from pointer to int1
            # NOTE: Currently causes hangs on bool UTs for ROCm
            line += ".to(tl.int1)"

        load_buffer = self.get_load_buffer(indexing)
        result_var = self.cse.generate(load_buffer, line)
        assert isinstance(result_var, TritonCSEVariable)
        result_var.mask_vars = indexing.mask_vars  # type: ignore[assignment]

        if append_broadcast:
            line = f"tl.broadcast_to({result_var}, {append_broadcast})"
            result_var = self.cse.generate(load_buffer, line)

        if advance_block_ptr:
            load_buffer.writeline(advance_block_ptr)

        if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex):
            self.outside_loop_vars.add(result_var)

        return result_var

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        var = self.args.output(name)
        original_index = index
        indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None)

        # Guard against write-after-read corruption in triton.
        # See https://github.com/openai/triton/issues/1615
        # This triton bug means that a load which is broadcasted over multiple
        # warps may see the result of a store that happens later in the triton
        # program. The workaround is to add a barrier before storing, which
        # enforces that all warps have already read the data.
        is_inplace = name in self.args.inplace_buffers
        is_broadcasted = self.is_broadcasted(original_index)
        if is_inplace and is_broadcasted:
            self.stores.writeline(DeferredLine(name, "tl.debug_barrier()"))

        advance_block_ptr = None
        if isinstance(indexing, BlockPtrOptions):
            block_ptr, advance_block_ptr, other = self.codegen_block_ptr(
                name, var, indexing
            )
            # block_ptr stores don't do implicit casting
            line = self.codegen_block_ptr_store_line(
                name, indexing, block_ptr, value, other
            )
        elif mode is None:
            line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
        elif mode == "atomic_add":
            line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')"
        else:
            raise NotImplementedError(f"store mode={mode}")
        self.stores.writeline(DeferredLine(name, line))
        if advance_block_ptr:
            self.stores.writeline(advance_block_ptr)

        if not self.inside_reduction:
            self.outside_loop_vars.add(value)

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """

        # Triton performance for bucketize_binary_search is much better when the number
        # of threads equals the number of elements.
        # If we're trying to use a bucketize kernel, we should make sure that an
        # autotuning config with num_elements_per_warp=32 exists.
        self.autotune_hints.add(AutotuneHint.ELEMENTS_PER_WARP_32)

        offsets_ptr = self.args.input(offsets_name)
        block_size = self.dense_size_str()
        offsets_size_str = self.index_to_str(offsets_size)

        if indexing_dtype == torch.int32:
            triton_dtype = "tl.int32"
        elif indexing_dtype == torch.int64:
            triton_dtype = "tl.int64"
        else:
            raise NotImplementedError(
                "Bucketize only supports indexing with int32 and int64"
            )

        result = self.cse.generate(
            self.compute,
            f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size_str}, {block_size})",  # noqa: B950 line too long
        )

        return result

    def reduction_resize(self, value):
        ndims = self.triton_tensor_ndim()
        if ndims == 1:
            return f"triton_helpers.promote_to_tensor({value})"

        sizes = [":"] * ndims
        sizes[-1] = "None"
        return f"{value}[{', '.join(sizes)}]"

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        assert self.inside_reduction
        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
        self.filter_masks(masks)
        masks = sorted(masks)
        if self._load_mask:
            masks.append(self._load_mask)
        reduction_range_prefix = self.range_trees[-1].prefix

        # Say we have
        # tmp0 = ops.constant(1, torch.int64)
        # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0)
        # tmp0 in the triton code is either a scalar, or single-element tensor
        # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1
        # To avoid this, we broadcast to the expected shape first.
        dense_size_str = self.dense_size_str()
        value = self._map_tuple_or_scalar(
            lambda v: self.cse.generate(
                self.compute, f"tl.broadcast_to({v}, {dense_size_str})"
            ),
            value,
        )

        dim: int
        root_op: str

        def final_reduction(value):
            use_helper = reduction_type in {"any", "max", "min", "prod"}
            module = "triton_helpers" if use_helper else "tl"
            if reduction_type in {"max", "min"}:
                return self.reduction_resize(
                    f"{module}.{reduction_type}2({value}, {dim})"
                )
            return self.reduction_resize(f"{module}.{reduction_type}({value}, {dim})")

        def final_argreduce(buffer, result_var, value, index):
            buffer.splice(
                f"""\
                _, {result_var}_tmp = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
                {result_var} = {self.reduction_resize(f'{result_var}_tmp')}
                """
            )

        cache_key = (src_dtype, reduction_type, value)
        if cache_key in self.cse.reduction_cache:
            return self.cse.reduction_cache[cache_key]

        dim = self.triton_tensor_ndim() - 1
        acc_type = triton_acc_type(src_dtype)
        result_var: Any = self.cse.newvar()
        result_var.mask_vars = OrderedSet(var for var in masks if var[0] != "r")
        cond = " & ".join(masks)

        def where_cond(tval, fval):
            if not cond:
                return tval
            return TritonKernelOverrides.where(cond, tval, fval)

        if self.persistent_reduction:
            default = ir.Reduction.default_value(reduction_type, src_dtype)
            default = self._map_tuple_or_scalar(constant_repr, default)

            def _mask_value(value, default):
                return self.cse.generate(self.compute, where_cond(value, default))

            if isinstance(value, tuple):
                masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
            else:
                masked_value = _mask_value(value, default)

            if reduction_type in {"argmax", "argmin"}:
                accumulator_index = str(
                    self.cse.generate(
                        self.compute,
                        f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)",
                    )
                )
                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]
                final_argreduce(
                    self.compute, result_var, masked_value, accumulator_index
                )
            elif reduction_type == "welford_reduce":
                # For persistent reductions, don't bother with
                # welford's algorithm since it uses more registers, and
                # taking two reductions doesn't increase memory usage.
                result_var = self.welford_reduce_fallback(dtype, value)
            elif reduction_type == "welford_combine":
                mean, m2, weight = masked_value
                welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})"
                mean, m2, weight = (self.cse.newvar() for _ in range(3))
                self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}")

                result_var = tuple(
                    self.cse.generate(self.compute, self.reduction_resize(var_name))
                    for var_name in (mean, m2, weight)
                )
            else:
                result_var = self.cse.generate(
                    self.compute, final_reduction(masked_value)
                )
        else:
            accumulator = f"_{result_var}"
            default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
            default = self._map_tuple_or_scalar(constant_repr, default)
            if not isinstance(default, tuple):
                self.body.writeline(
                    f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})"
                )

            if reduction_type in {"argmax", "argmin"}:
                accumulator_index = f"_{result_var}_index"
                long_max = torch.iinfo(torch.int64).max
                self.body.writeline(
                    f"{accumulator_index} = tl.full({self.dense_size_str()}, {long_max}, tl.int64)"
                )
                root_op = {"argmax": "max", "argmin": "min"}[reduction_type]

                self.compute.splice(
                    f"""\
                {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
                    {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
                )
                {accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
                {accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
                """
                )
                final_argreduce(self.suffix, result_var, accumulator, accumulator_index)
            elif is_welford_reduction(reduction_type):
                accumulator = f"{result_var}_mean"
                accumulator_m2 = f"{result_var}_m2"
                accumulator_weight = f"{result_var}_weight"
                self.body.writeline(
                    f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})"
                )
                self.body.writeline(
                    f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})"
                )
                self.body.writeline(
                    f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})"
                )

                if reduction_type == "welford_combine":
                    mean, m2, weight = value
                    self.compute.splice(
                        f"""\
                    {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine(
                        {accumulator}, {accumulator_m2}, {accumulator_weight},
                        {mean}, {m2}, {weight}
                    )
                    """
                    )
                else:
                    assert reduction_type == "welford_reduce"
                    self.compute.splice(
                        f"""\
                    {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce(
                        {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0
                    )
                    """
                    )

                self.compute.splice(
                    f"""\
                {accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
                {accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
                {accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
                """
                )

                result_mean = result_var
                result_m2 = self.cse.newvar()
                result_weight = self.cse.newvar()
                self.suffix.splice(
                    f"""\
                {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = triton_helpers.welford(
                    {accumulator}, {accumulator_m2}, {accumulator_weight}, {dim}
                )
                {result_mean} = {self.reduction_resize(f'{result_mean}_tmp')}
                {result_m2} = {self.reduction_resize(f'{result_m2}_tmp')}
                {result_weight} = {self.reduction_resize(f'{result_weight}_tmp')}
                """
                )
                result_var = result_mean, result_m2, result_weight
            else:
                combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype)
                updated = combine_fn(accumulator, value)
                self.compute.writeline(
                    f"{accumulator} = {where_cond(updated, accumulator)}"
                )

                if src_dtype == torch.bool:
                    # This is only really used for aten.any. It changes the
                    # final reduction of a non-persistent reduction from
                    # tmp5 = triton_helpers.max(_tmp5, 1)[:, None]
                    # to
                    # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1)
                    # which is needed because tl.reduce doesn't support tl.int1
                    accumulator = f"{accumulator}.to(tl.int8)"
                    result_type = triton_compute_type(dtype)
                    self.suffix.writeline(
                        f"{result_var} = {final_reduction(accumulator)}.to({result_type})"
                    )
                else:
                    self.suffix.writeline(
                        f"{result_var} = {final_reduction(accumulator)}"
                    )

        self.cse.reduction_cache[cache_key] = result_var

        if isinstance(result_var, tuple):
            assert all(isinstance(x, TritonCSEVariable) for x in result_var)
            self.outside_loop_vars |= OrderedSet(result_var)
        else:
            assert isinstance(result_var, TritonCSEVariable)
            self.outside_loop_vars.add(result_var)

        return result_var

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        assert self.inside_reduction
        self.inside_reduction = False
        indexing = self.indexing(index, block_ptr=True)
        self.inside_reduction = True
        var = self.args.output(name)

        if isinstance(indexing, BlockPtrOptions):
            self.suffix.writeline(
                DeferredLine(
                    name,
                    self.codegen_block_ptr_store_line(
                        name,
                        indexing,
                        indexing.format(var),
                        value,
                        f", boundary_check={indexing.boundary_check()!r}",
                    ),
                )
            )
        else:
            assert isinstance(indexing, IndexingOptions)
            self.suffix.writeline(
                DeferredLine(
                    name,
                    f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})",
                )
            )

    def _lift_helper(self, fn, num_args) -> str:
        # Lift an IR function for scan operations into a triton function
        # in the global namespace
        helper = IndentedBuffer()
        helper.writeline("@triton.jit")
        args = [tuple(f"arg{i}_{n}" for n in range(num_args)) for i in range(2)]
        signature = ", ".join(itertools.chain.from_iterable(args))
        helper.writeline(f"def {{name}}({signature}):")

        cse = CSE(prefix="", suffix="")
        overrides = TritonOverrides(V.MockHandler())

        # Build a name that changes depending on fn to work around a triton bug
        # where the combine_fn to reduce and scan is not hashed, and so different
        # scan ops may collide in the triton cache.
        # This is fixed with the latest triton pin, but not the triton-rocm pin.
        helper_name = "_triton_helper_fn"

        class CSEProxy:
            def __getattr__(self, name: str) -> Callable[..., CSEVariable]:
                def inner(*args, **kwargs):
                    nonlocal helper_name
                    helper_name += f"_{name}"
                    return cse.generate(
                        helper,
                        getattr(overrides, name)(*args, **kwargs),
                    )

                return inner

        with helper.indent(), V.set_ops_handler(CSEProxy()):
            outputs = fn(*args)
            outputs = ", ".join(str(output) for output in outputs)
            helper.writeline(f"return {outputs}")

        return self.helper_functions.add(helper.getvalue(), base_name=helper_name)

    def scan(
        self,
        dtypes: Tuple[torch.dtype, ...],
        combine_fn: Callable[
            [Tuple[CSEVariable, ...], Tuple[CSEVariable, ...]], Tuple[CSEVariable, ...]
        ],
        values: Tuple[CSEVariable, ...],
    ) -> Tuple[CSEVariable, ...]:
        assert self.inside_reduction
        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
        self.filter_masks(masks)
        masks = sorted(masks)
        assert not self._load_mask, "ops.scan not supported inside ops.masked"
        reduction_range_prefix = self.range_trees[-1].prefix

        broadcasted_values = []
        accumulators = []

        cse_compute = functools.partial(self.cse.generate, self.compute)
        combine_helper_fn = self._lift_helper(combine_fn, len(values))
        dim = self.triton_tensor_ndim() - 1

        for value, dtype in zip(values, dtypes):
            acc_type = triton_acc_type(dtype)
            cond = " & ".join(masks)

            value_dtype = self.cse.generate(
                self.compute,
                f"{value}.to({triton_compute_type(dtype)})",
            )
            value = self.cse.generate(
                self.compute,
                f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})",
            )
            broadcasted_values.append(value)

            acc_type = triton_acc_type(dtype)
            cond = " & ".join(masks)

            if not self.persistent_reduction:
                accumulator = self.cse.newvar()
                reduced_size = self.dense_size_list()
                reduced_size[-1] = "1"
                reduced_size = f"[{', '.join(reduced_size)}]"

                default = "float('nan')" if dtype.is_floating_point else "-1"
                self.body.writeline(
                    f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})"
                )

                accumulators.append(accumulator)

        def csv(values):
            return " ".join(f"{value}," for value in values)

        def cse_multiple(line, n, masks):
            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
            if all(cache_key in self.cse.cache for cache_key in cache_keys):
                return [self.cse.cache[cache_key] for cache_key in cache_keys]
            result_vars = [self.cse.newvar() for _ in range(n)]
            self.compute.writeline(
                f"{csv(result_vars)} = {line}",
            )
            for result_var, cache_key in zip(result_vars, cache_keys):
                if masks:
                    result_var.mask_vars = masks  # type: ignore[attr-defined]
                self.cse.cache[cache_key] = result_var
            return tuple(result_vars)

        partial_scan_vars = cse_multiple(
            f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})",
            len(values),
            masks,
        )

        if not self.persistent_reduction:
            # tl.reduce doesn't work for non-commutative operators, so instead
            # of repeating the scan op as a reduction, we use sum to select the
            # last scan value
            partial_reduce_vars = [
                cse_compute(
                    f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)"
                )
                for partial_scan_var in partial_scan_vars
            ]
            accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars))
            full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars)
            result_vars = [
                cse_compute(f"tl.where(roffset > 0, {full_scan}, {partial_scan})")
                for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars)
            ]
            for acc_next, accumulator, partial_reduce in zip(
                accs_next, accumulators, partial_reduce_vars
            ):
                self.compute.writeline(
                    f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})"
                )
        else:
            result_vars = partial_scan_vars

        for result_var in result_vars:
            result_var.mask_vars = masks  # type: ignore[attr-defined]

        return tuple(result_vars)

    def sort(
        self,
        dtypes: Tuple[torch.dtype, ...],
        values: Tuple[CSEVariable, ...],
        stable: bool,
        descending: bool,
    ) -> Tuple[CSEVariable, ...]:
        assert self.inside_reduction
        masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees)
        self.filter_masks(masks)
        masks = sorted(masks)
        assert not self._load_mask, "ops.sort not supported inside ops.masked"
        assert (
            self.persistent_reduction
        ), "ops.sort is only supported in persistent reductions"
        reduction_range_prefix = self.range_trees[-1].prefix

        cse_compute = functools.partial(self.cse.generate, self.compute)
        dim = self.triton_tensor_ndim() - 1

        broadcasted_values = [
            cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})")
            for value in values
        ]

        def csv(values):
            return " ".join(f"{value}," for value in values)

        def cse_multiple(line, n, masks):
            cache_keys = [f"{line}, {i}, {masks}" for i in range(n)]
            if all(cache_key in self.cse.cache for cache_key in cache_keys):
                return [self.cse.cache[cache_key] for cache_key in cache_keys]
            result_vars = [self.cse.newvar() for _ in range(n)]
            self.compute.writeline(
                f"{csv(result_vars)} = {line}",
            )
            for result_var, cache_key in zip(result_vars, cache_keys):
                if masks:
                    result_var.mask_vars = masks  # type: ignore[attr-defined]
                self.cse.cache[cache_key] = result_var
            return tuple(result_vars)

        assert self.range_trees[-1].prefix == "r"
        rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel"

        if len(values) == 2:
            line = (
                f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]},"
                f" {rnumel}, {dim}, stable={stable}, descending={descending})"
            )
            result_vars = cse_multiple(line, len(values), masks)
        else:
            raise AssertionError("Unhandled sort")

        for result_var, input_var in zip(result_vars, values):
            result_var.mask_vars = masks  # type: ignore[attr-defined]
            result_var.bounds = input_var.bounds

        return tuple(result_vars)

    def codegen_body(self):
        """
        Concatenate output code from indexing_code, loads, compute, stores,
        suffix into self.body.

        For pointwise kernels, this is called just once at the end.

        For reduction kernels, this generates a loop over the reduction
        axis.
        """
        if not (
            self.indexing_code
            or self.loads
            or self.stores
            or self.compute
            or self.suffix
        ):
            return

        if self.inside_reduction and self.range_trees[-1].is_loop:
            self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
            with self.body.indent():
                # last range tree is always reduction
                self.iteration_ranges_codegen_header(self.range_trees[-1], self.body)
                self.body.splice(self.indexing_code)
                self.body.splice(self.loads)
                self.body.splice(self.compute)
                self.body.splice(self.stores)

            # invalidate any caches that came from inside the reduction loop
            self.cse.invalidate(self.outside_loop_vars)
            self.range_trees[-1].cache_clear()
        else:
            self.body.splice(self.indexing_code)
            self.body.splice(self.loads)
            self.body.splice(self.compute)
            self.body.splice(self.stores)
        self.body.splice(self.suffix)
        self.indexing_code.clear()
        self.loads.clear()
        self.compute.clear()
        self.stores.clear()
        self.suffix.clear()

    def codegen_kernel_benchmark(self, num_gb, grid=None):
        result = IndentedBuffer()
        argdefs, call_args, signature, _ = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            name_cnt = itertools.count()
            var_names = []
            for arg_name, arg_sig in zip(call_args, signature):
                var_name = f"arg_{next(name_cnt)}"
                buf = V.graph.try_get_buffer(arg_name)
                if buf:
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offsets will have the same benchmark harness.
                    # We can dedup kernel definitions in this case.
                    if "seed_offset" in arg_sig.name:
                        symval_hint = 0
                    result.writeline(f"{var_name} = {symval_hint}")
                elif isinstance(arg_sig, WorkspaceArg):
                    device = V.graph.scheduler.get_current_device_or_throw()
                    nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes)
                    result.writeline(
                        f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)"
                    )
                else:
                    raise KeyError(
                        f"Could not find the buffer or const tensor for {arg_name}"
                    )
                var_names.append(var_name)
            result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
        if grid is None:
            grid = []
            extra_args = []
            extra_args_str = None
            for tree in self.active_range_trees():
                expr = pexpr(V.graph.sizevars.size_hint(tree.numel))
                extra_args.append(expr)
                if tree.prefix != "r":
                    grid.append(expr)
            if self.need_numel_args():
                extra_args_str = ", ".join(map(str, extra_args)) + ", "
            else:
                extra_args_str = ""
            grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})"
        else:
            grid_arg = f"grid={grid}"
        current_device = V.graph.scheduler.get_current_device_or_throw()
        index = current_device.index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_raw_stream({index})")
                result.writeline(
                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                result.writeline(
                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
                )

        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline(
                "from torch._inductor.runtime.benchmarking import benchmarker"
            )
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)"
            )
            result.writeline(f"num_gb = {num_gb}")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
            )

        return result

    def imports_for_benchmark_kernel(self):
        return textwrap.dedent(
            """
            from torch._dynamo.testing import rand_strided
            {}
            import torch
            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
            """.format(
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
            )
        )

    def _get_heuristic(self):
        if self.persistent_reduction:
            assert self.inside_reduction
            return "persistent_reduction"
        elif self.inside_reduction:
            return "reduction"
        return "pointwise"

    @staticmethod
    def inductor_meta_common():
        inductor_meta = {
            "backend_hash": torch.utils._triton.triton_hash_with_backend(),
            "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
            "assert_indirect_indexing": config.assert_indirect_indexing,
            "autotune_local_cache": config.autotune_local_cache,
            "autotune_pointwise": config.triton.autotune_pointwise,
            "autotune_remote_cache": config.autotune_remote_cache,
            "force_disable_caches": config.force_disable_caches,
            "dynamic_scale_rblock": config.dynamic_scale_rblock,
            "max_autotune": config.max_autotune,
            "max_autotune_pointwise": config.max_autotune_pointwise,
            "min_split_scan_rblock": config.triton.min_split_scan_rblock,
            "spill_threshold": config.triton.spill_threshold,
            "store_cubin": config.triton.store_cubin,
        }
        if torch.version.hip is not None:
            inductor_meta["is_hip"] = True
        if config.is_fbcode():
            inductor_meta["is_fbcode"] = True
        if config.profile_bandwidth:
            inductor_meta["profile_bandwidth"] = config.profile_bandwidth
            inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
            inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
            inductor_meta[
                "profile_bandwidth_with_do_bench_using_profiling"
            ] = config.profile_bandwidth_with_do_bench_using_profiling
        if config.coordinate_descent_tuning:
            inductor_meta[
                "coordinate_descent_tuning"
            ] = config.coordinate_descent_tuning
            inductor_meta[
                "coordinate_descent_search_radius"
            ] = config.coordinate_descent_search_radius
            inductor_meta[
                "coordinate_descent_check_all_directions"
            ] = config.coordinate_descent_check_all_directions
        return inductor_meta

    def codegen_kernel(self, name=None):
        code = IndentedBuffer()

        size_hints = []
        for numel in self.numels:
            numel_hint = V.graph.sizevars.symbolic_hint(numel)
            if not isinstance(numel_hint, (int, sympy.Integer)):
                # This default heuristic hint was picked carefully: it is
                # large, to ensure that we don't shrink the block size (since
                # if you don't have many elements, it'd be wasteful to pick a
                # large block size). Since we don't know how many elements we
                # might have, we should be OK with some inefficiency to make
                # sure we handle the large case well. 8192 is the largest
                # block size we support, so we pick that.
                #
                # If we have a better hint for unbacked SymInts (e.g., because
                # a user told us, or we are tracking upper bounds) we could
                # use that here.
                size_hint = 8192
            else:
                size_hint = next_power_of_2(int(numel_hint))
            size_hints.append(size_hint)

        if not self.inside_reduction:
            size_hints.pop()

        heuristics = self._get_heuristic()

        if name is None:
            code.splice(gen_common_triton_imports())

            if config.benchmark_kernel:
                code.splice(self.imports_for_benchmark_kernel())

        argdefs, _, signature, _ = self.args.python_argdefs()
        # maps actual expression to SizeArg if it is in sizevars replacements
        for i, arg in enumerate(signature):
            if isinstance(arg, SizeArg):
                # mypy is unhappy about the sympy.Expr
                # type for the key of the dict below
                symbol = cast(sympy.Symbol, arg.expr)
                if symbol in V.graph.sizevars.inv_precomputed_replacements:
                    signature[i] = SizeArg(
                        arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
                    )

        mutated_args: OrderedSet[str] = OrderedSet()
        for mutation in self.mutations:
            if mutation in self.args.input_buffers:
                mutated_args.add(self.args.input_buffers[mutation])
            if (
                mutation in self.args.inplace_buffers
                and mutation not in V.graph.removed_buffers
                and mutation not in self.removed_buffers
            ):
                mutated_args.add(self.args.inplace_buffers[mutation].inner_name)
            if mutation in self.args.output_buffers:
                mutated_args.add(self.args.output_buffers[mutation])

        # workspace arguments are mutated, but are not marked as mutations in self.mutations
        # because their buffers are added during codegen, and aren't tracked during
        # lowering/scheduling. So we add them as mutated_args explicitly below.
        #
        # In the logic below, we only mark the workspaces as mutated if they are marked with
        # zero_fill: that's because, if we don't expect the buffer to be pre-filled with
        # zeros, then, although we still mutate the data, we don't care about those
        # mutations because we don't make any assumptions about the contents of the
        # workspace buffer.
        for argname, arg in zip(argdefs, signature):
            if isinstance(arg, WorkspaceArg) and arg.zero_fill:
                mutated_args.add(argname)

        mutated_args = sorted(mutated_args)

        triton_meta_signature = signature_to_meta(
            signature, size_dtype=self.index_dtype
        )
        triton_meta = {
            "signature": triton_meta_signature,
            "device": DeviceProperties.create(
                V.graph.scheduler.get_current_device_or_throw()
            ),
            "constants": {},
        }

        inductor_meta = {
            "autotune_hints": set(self.autotune_hints),
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "mutated_arg_names": mutated_args,
            "no_x_dim": self.no_x_dim,
            "num_load": self.num_load,
            "num_reduction": self.num_reduction,
            **self.inductor_meta_common(),
        }

        num_gb = None
        if config.benchmark_kernel or config.profile_bandwidth:
            num_gb = self.estimate_kernel_num_bytes() / 1e9
            inductor_meta["kernel_num_gb"] = num_gb

        for tree in self.active_range_trees():
            sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
            signature.append(sizearg)
            triton_meta_signature[len(argdefs)] = signature_of(
                sizearg, size_dtype=self.index_dtype
            )
            argdefs.append(f"{tree.prefix}numel")
            # constexpr version causes issues, see
            # https://github.com/pytorch/torchdynamo/pull/1362
            # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint(
            #     tree.numel
            # )
            # argdefs.append(f"{tree.prefix}numel: tl.constexpr")
        triton_meta["configs"] = [config_of(signature)]

        # The Triton compiler includes equal_to_1 args into constants even
        # when they are not constexpr. Otherwise there may be a segfault
        # when launching the Inductor-compiled Triton kernel.
        # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
        # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
        for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
            triton_meta["constants"][arg_num] = 1  # type: ignore[index]

        self.triton_meta = triton_meta

        for tree in self.range_trees:
            if tree.prefix == "r" and self.persistent_reduction:
                # RBLOCK for persistent_reduction is defined in codegen_static_numels
                continue
            if tree.tensor_dim is None:
                continue
            argdefs.append(f"{tree.prefix.upper()}BLOCK : tl.constexpr")

        self.codegen_body()

        for helper in self.helper_functions:
            code.writeline("")
            code.splice(helper)

        if self.inside_reduction:
            reduction_hint = self.reduction_hint
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """
        else:
            tile_hint = ""
            if len(size_hints) == 2:
                if len(signature) == 4:  # input, output and 2 args
                    tile_hint = "tile_hint=TileHint.SQUARE,"
                else:
                    tile_hint = "tile_hint=TileHint.DEFAULT,"
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r}, {tile_hint}
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r},
                    min_elem_per_thread={self.min_elem_per_thread}
                )
                @triton.jit
            """
        code.splice(heuristics_line)
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )
        with code.indent():
            self.codegen_static_numels(code)
            for old, new in self.args.aliases():
                code.writeline(f"{old} = {new}")
            code.splice(self.body)

        if config.benchmark_kernel:
            code.splice(self.codegen_kernel_benchmark(num_gb))

        return code.getvalue()

    def _get_persistent_RBLOCK(self, rnumel):
        rnumel = V.graph.sizevars.simplify(rnumel)
        if isinstance(rnumel, (sympy.Integer, int)):
            val = int(rnumel)
            val = next_power_of_2(val)
        else:
            val = 128
            while not V.graph.sizevars.statically_known_leq(rnumel, val):
                assert val <= 16 * 1024, f"Failed to find static RBLOCK for {rnumel}"
                val *= 2
        return val

    def codegen_static_numels(self, code):
        """
        We get a small speedup from hard coding numels if they are static.

        This code stomps on the passed-in values by writing a constant to the top of the kernel.

        In a kernel like:
        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):

        We would add
        xnumel = 4096
        rnumel = 768

        after the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes
        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
        knows that it's a static numel, as that you just plop a constant into the kernel.
        """
        for tree in self.range_trees:
            if tree.prefix != "r" or self.inside_reduction:
                simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
                if isinstance(simplified_tree_numel, (sympy.Integer, int)):
                    code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")

            if tree.prefix == "r" and self.persistent_reduction:
                val = self._get_persistent_RBLOCK(tree.numel)
                code.writeline(f"RBLOCK: tl.constexpr = {val}")

            if tree.prefix == "x" and self.no_x_dim:
                code.writeline("XBLOCK: tl.constexpr = 1")

    def _get_grid_fn(self):
        return "grid"

    def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
        # TODO(jansel): if there are constants, we shouldn't bother passing them as args
        for tree in self.range_trees:
            if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)):
                expr = tree.numel
            else:
                expr = V.graph.wrapper_code.generate_numel_expr(name, tree)

            if tree.prefix != "r" or self.inside_reduction:
                call_args.append(expr)
                arg_types.append(type(expr))
            if tree.grid_dim is not None:
                grid.append(expr)

    def call_kernel(self, name: str, node: Optional[IRNode] = None):
        wrapper = V.graph.wrapper_code
        wrapper.write_triton_header_once()
        _, call_args, _, arg_types = self.args.python_argdefs()
        grid: List[Any] = []
        self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
        current_device = V.graph.scheduler.get_current_device_or_throw()

        if self.args.workspace_arg is not None:
            ws = self.args.workspace_arg
            wrapper.generate_workspace_allocation(
                ws.nbytes, current_device, ws.zero_fill
            )

        grid = wrapper.generate_default_grid(name, grid)
        wrapper.generate_kernel_call(
            name,
            call_args,
            grid,
            current_device.index,
            cuda=True,
            triton=True,
            arg_types=arg_types,
            grid_fn=self._get_grid_fn(),
            triton_meta=self.triton_meta,
        )

        if self.args.workspace_arg is not None:
            wrapper.writeline(wrapper.make_free_by_names(["workspace"]))

    def codegen_nan_check(self):
        wrapper = V.graph.wrapper_code
        _, call_args, arg_signatures, _ = self.args.python_argdefs()
        for arg, arg_signature in zip(call_args, arg_signatures):
            if isinstance(arg_signature, TensorArg):
                if V.graph.cpp_wrapper:
                    if config.abi_compatible:
                        wrapper.writeline(
                            f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));'
                        )
                    else:
                        wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});')
                else:
                    line = f"assert not {arg}.isnan().any().item()"
                    wrapper.writeline(line)
                    line = f"assert not {arg}.isinf().any().item()"
                    wrapper.writeline(line)

    def create_cse_var(self, *args, **kwargs):
        return TritonCSEVariable(*args, **kwargs)

    def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry):
        line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
        if entry.root.is_loop:
            self.indexing_code.writeline(line)
        else:
            # lift non-reduction stores outside loop
            self.body.writeline(line)

    def iteration_ranges_ranges_code(self, entry):
        assert entry.tensor_dim is not None
        size = self.indexing_size_str(entry.tensor_dim)
        index_dtype = self.index_dtype
        convert = f".to({index_dtype})" if index_dtype != "tl.int32" else ""
        return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{convert}"

    def iteration_ranges_scalar_code(self, entry, value):
        index_dtype = self.index_dtype
        ndim = self.triton_tensor_ndim()
        size = [1] * ndim
        return f"tl.full({size}, {value}, {index_dtype})"

    def iteration_ranges_get_pid(self, entry):
        assert entry.grid_dim is not None
        key = f"tl.program_id({entry.grid_dim})"
        # y_grid has a limit, so express it in terms of y and z in case of overflow.
        # z grid is only exercised when max_tiles == 3 (off by default).
        if (
            entry.grid_dim == 1
            and not entry.has_zdim
            and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid())
        ):
            # For ynumel larger than max_ygrid, we need to use zdim.
            # For each z dimension, there are tl.num_programs(1) yblocks, which is passed by grid(x, y, z).
            # So, we need to add tl.program_id(z) * tl.num_programs(y) * YBLOCK to get the correct yoffset.
            key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))"
        pid = entry.pid_cache.get(key, key)
        if self.index_dtype != "tl.int32":
            return f"{pid}.to({self.index_dtype})"
        return pid

    def _has_constant_mask(self, tree: IterationRangesRoot):
        if not self.optimize_mask:
            return False
        if V.graph.sizevars.statically_known_equals(tree.numel, 1):  # type: ignore[arg-type]
            return True
        # Masks are superfluous if numel is a multiple of BLOCK
        # (We use the fact that BLOCK is required by triton to be a power of 2)
        if tree.prefix == "r" and self.persistent_reduction:
            max_block = self._get_persistent_RBLOCK(tree.numel)
        elif tree.prefix == "x" and self.no_x_dim:
            max_block = 1
        else:
            if tree.prefix.upper() not in TRITON_MAX_BLOCK:
                return False
            max_block = TRITON_MAX_BLOCK[tree.prefix.upper()]

        # Optional optimization: if block divides numel exactly, we will
        # never need to do a masked load to handle stragglers at the end.
        # It's faster to avoid masking at all. But it is sound to always
        # mask.
        return V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block)

    def filter_masks(self, mask_vars):
        for tree in self.range_trees:
            if self._has_constant_mask(tree):
                mask_vars.discard(f"{tree.prefix}mask")

    def iteration_ranges_codegen_header(self, entry, code):
        x = entry.prefix
        if entry.is_loop:
            code.writeline(f"{entry.name} = {x}offset + {x}base")
        elif entry.grid_dim is None:
            # no need to "{x}offset = "
            code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}")
            code.writeline(f"{x}offset = 0")
        else:
            if entry.tensor_dim is not None:
                line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}"
            else:
                line = self.iteration_ranges_scalar_code(entry, f"{x}offset")
            code.writelines(
                [
                    f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {x.upper()}BLOCK",
                    f"{entry.name} = {line}",
                ]
            )

        if self._has_constant_mask(entry):
            sizes = self.dense_size_str()
            code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)")
        else:
            code.writeline(f"{x}mask = {entry.name} < {x}numel")


class TritonScheduling(SIMDScheduling):
    int32_type = "tl.int32"
    int64_type = "tl.int64"
    kernel_type = TritonKernel
    backend_features = dict.fromkeys(  # dict for deterministic order
        [
            BackendFeature.FOREACH,
            BackendFeature.BUCKETIZE,
            BackendFeature.INPLACE_BUFFERS,
            BackendFeature.MASKED_SCATTER_WITH_INDEX,
            BackendFeature.SCAN,
            BackendFeature.TRITON_TEMPLATES,
        ]
    )
    if torch.version.hip is None:
        backend_features.update(
            dict.fromkeys(
                [
                    # TODO: Move this above when ROCm triton adds support for multiple inputs
                    BackendFeature.TUPLE_REDUCTION,
                    BackendFeature.SORT,
                ]
            )
        )

    @classmethod
    def get_backend_features(cls, device: torch.device):
        return cls.backend_features

    def codegen_comment(self, node_schedule):
        wrapper = V.graph.wrapper_code
        origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
        if origins:
            wrapper.writeline(origins)

        if config.debug_fusion:
            from torch._inductor.scheduler import (
                BaseSchedulerNode,
                ForeachKernelSchedulerNode,
            )

            if not any(
                isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule
            ):
                # We should probably look at the nodes inside a foreach
                # schedule node as well
                node_names = [
                    n.get_name()
                    for n in node_schedule
                    if isinstance(n, BaseSchedulerNode)
                ]
                wrapper.writeline(
                    f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
                )

    def define_kernel(self, src_code, node_schedule, kernel):
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            fused_name = (
                get_fused_kernel_name(node_schedule, config.triton.descriptive_names)
                if config.triton.descriptive_names
                else ""
            )
            kernel_category = get_kernel_category_by_source_code(src_code)[:3]
            kernel_name = "_".join(
                ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()]
            )
            # use the original src_code as the key
            wrapper.src_to_kernel[src_code] = kernel_name
            subs_name = kernel_name if config.triton.unique_kernel_names else "triton_"

            # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name
            # even when unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set
            # to "triton_" to maximize caching opportunities (when unique_kernel_names = False).
            src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name)

            # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
            # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
            src_code = src_code.replace("#pragma CMT", "#")

            basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")

            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
            compile_wrapper.splice(src_code, strip=True)
            current_device = V.graph.scheduler.get_current_device_or_throw()
            compile_wrapper.writeline(f"''', device_str='{current_device.type}')")

            metadata_comment = f"# kernel path: {kernel_path}"
            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment += "\n" + origins + "\n" + detailed_origins
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )

            # log kernel metadata for offline analysis.
            # E.g. one can find all unaligned inner reductions and check, kernel by
            # kernel, whether padding helps with the perf.
            if is_metric_table_enabled("kernel_metadata"):
                log_kernel_metadata(kernel_name, kernel_path, src_code)

        return kernel_name

    def benchmark_fused_nodes(self, nodes):
        with preserve_rng_state():
            src_code = self.generate_kernel_code_from_nodes(
                nodes, benchmark_kernel=True
            )
            mod = PyCodeCache.load(src_code)

            def cache_file_path():
                assert mod.__file__ is not None
                return os.path.splitext(mod.__file__)[0] + ".kernel_perf"

            def load_cache():
                path = cache_file_path()
                if os.path.exists(path):
                    with open(path) as fd:
                        return float(fd.read())
                return None

            def store_cache():
                path = cache_file_path()
                with open(path, "w") as fd:
                    fd.write(str(ms))

            log.debug(
                "kernel src code for %s written to: %s",
                {n.get_name() for n in nodes},
                mod.__file__,
            )
            ms = load_cache()
            if ms is not None:
                return ms, mod.__file__

            args = mod.get_args()
            call = mod.call
            wrapped_jit_function = mod.triton_

            # call once to trigger the compilation
            try:
                call(wrapped_jit_function.clone_args(*args)[0])
            except Exception as e:
                log.debug(
                    "Exception (%s) in compiling fused nodes %s",
                    e,
                    {n.get_name() for n in nodes},
                )
                ms = float("inf")
                store_cache()
                return ms, mod.__file__

            launchers = wrapped_jit_function.launchers
            assert len(launchers) == 1
            if launchers[0].n_spills > 0:
                # skip benchmarking the kernel if there are register spills
                ms = float("inf")
            else:
                # We have to clone the in-place updated arguments to avoid earlier calls
                # generating out-of-range indices for later calls.
                ms = benchmarker.benchmark_gpu(
                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
                )

                # The overhead of cloning args biases the timing in favor of fusing the
                # kernel when the second fusion mutates or can be done in place.
                # TODO: this would be better as a hook in triton do_bench that resets
                # the input values between benchmarking runs.
                ms = ms - benchmarker.benchmark_gpu(
                    lambda: wrapped_jit_function.clone_args(*args)
                )

            log.debug(
                "The fused kernel for %s took %.3f ms to run",
                {n.get_name() for n in nodes},
                ms,
            )
            store_cache()
            return ms, mod.__file__

    def benchmark_combo_kernel(self, node_list):
        def cache_file_path():
            assert mod.__file__ is not None
            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"

        def load_cache():
            path = cache_file_path()
            if os.path.exists(path):
                with open(path) as fd:
                    return tuple(float(e) for e in fd.read().split())
            return (None, None)

        def store_cache():
            path = cache_file_path()
            with open(path, "w") as fd:
                fd.write(str(ms) + " " + str(ms_clone))

        total_ms, file_list = 0, []
        total_clone_ms = 0
        removed_buffers_orig = V.graph.removed_buffers
        V.graph.removed_buffers = OrderedSet(removed_buffers_orig)
        inplaced_to_remove_orig = V.graph.inplaced_to_remove
        V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig)
        enable_autotune = config.combo_kernels_autotune > 0
        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0
        kernel_code_list = self.generate_combo_kernel_code(
            subkernel_nodes=node_list,
            custom_part_algorithm=True,
            enable_autotune=enable_autotune,
            mixed_sizes=mixed_sizes,
            only_gen_src_code=True,
        )

        for src_code, _, node_group in kernel_code_list:
            fused_node_lists = [node.get_nodes() for node in node_group]
            names = [n.get_name() for nodes in fused_node_lists for n in nodes]

            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
            mod = PyCodeCache.load(src_code)

            log.debug(
                "kernel src code for %s written to: %s",
                names,
                mod.__file__,
            )
            ms, ms_clone = load_cache()
            if ms is not None:
                total_ms += ms
                total_clone_ms += ms_clone
                file_list.append(mod.__file__)
                continue

            args = mod.get_args()
            call = mod.call
            wrapped_jit_function = mod.triton_

            # call once to trigger the compilation
            call(wrapped_jit_function.clone_args(*args)[0])

            launchers = wrapped_jit_function.launchers
            assert len(launchers) == 1
            if launchers[0].n_spills > 0:
                # skip benchmarking the kernel if there are register spills
                ms = ms_clone = float("inf")
            else:
                # We have to clone the in-place updated arguments to avoid earlier calls
                # generating out-of-range indices for later calls.
                ms = benchmarker.benchmark_gpu(
                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
                )
                ms_clone = benchmarker.benchmark_gpu(
                    lambda: wrapped_jit_function.clone_args(*args)[0]
                )

            log.debug(
                "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs",
                {n.get_name() for n in node_group},
                ms,
                ms_clone,
            )
            store_cache()
            total_ms += ms
            total_clone_ms += ms_clone
            file_list.append(mod.__file__)
        V.graph.removed_buffers = removed_buffers_orig
        V.graph.inplaced_to_remove = inplaced_to_remove_orig
        return total_ms, total_clone_ms, file_list
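

# Note: benchmark_fused_nodes and benchmark_combo_kernel both persist their measured
# timings next to the generated module in "<module>.kernel_perf" files (see
# cache_file_path above), so deleting those files forces the kernels to be re-benchmarked.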