1# mypy: allow-untyped-defs 2import collections 3import copy 4import dataclasses 5import inspect 6import logging 7import threading 8from collections import defaultdict 9from typing import Any, Dict, List, Optional, Union 10 11import torch 12import torch.fx as fx 13import torch.utils._pytree as pytree 14from torch import Tensor 15from torch._C import DispatchKey 16from torch._ops import HigherOrderOperator 17from torch._prims_common import clone_preserve_strides 18from torch._subclasses.fake_tensor import FakeTensorMode 19from torch.fx.experimental.proxy_tensor import ( 20 disable_proxy_modes_tracing, 21 ProxyTorchDispatchMode, 22 track_tensor_tree, 23) 24 25 26log = logging.getLogger("torch._dynamo") 27 28 29############################################################################### 30# Kernel Side Table 31 32 33# We cannot put Triton Kernels into the FX graph as the graph nodes 34# do not support arbitrary functions. 35# Use a side table. 36# We use two dicts so that fetching both the kernel and id are O(1) 37class KernelSideTable: 38 id_to_kernel: Dict[int, Any] = {} 39 kernel_to_id: Dict[Any, int] = {} 40 constant_args: Dict[int, Any] = {} 41 lock = threading.Lock() 42 43 # Returns index on the table 44 def add_kernel(self, kernel) -> int: 45 with self.lock: 46 if kernel in self.kernel_to_id: 47 return self.kernel_to_id[kernel] 48 49 idx = len(self.id_to_kernel) 50 self.id_to_kernel[idx] = kernel 51 self.kernel_to_id[kernel] = idx 52 return idx 53 54 # Returns the triton kernel at the given index 55 def get_kernel(self, idx: int): 56 # No need to lock here as fetching from dict is atomic 57 assert idx in self.id_to_kernel 58 return self.id_to_kernel[idx] 59 60 # Not every constant arg can be added to the graph. Use this side table 61 # for constant args. 62 def add_constant_args(self, args) -> int: 63 with self.lock: 64 idx = len(self.constant_args) 65 self.constant_args[idx] = args 66 return idx 67 68 # Returns the constant args 69 def get_constant_args(self, idx: int): 70 # No need to lock here as fetching from dict is atomic 71 assert idx in self.constant_args 72 return self.constant_args[idx] 73 74 # Resets the table (only meant to be used in unit tests) 75 # This is only safe assuming single threaded execution 76 def reset_table(self) -> None: 77 self.id_to_kernel = {} 78 self.kernel_to_id = {} 79 self.constant_args = {} 80 81 82kernel_side_table = KernelSideTable() 83 84 85############################################################################### 86# Mutation Tracker 87 88 89@dataclasses.dataclass(frozen=True) 90class Param: 91 idx: int 92 93 94@dataclasses.dataclass(frozen=True) 95class Intermediate: 96 idx: int 97 98 def fake(self): 99 return self.idx < 0 100 101 102@dataclasses.dataclass(frozen=True) 103class Op: 104 name: str 105 fn_call_name: Optional[str] 106 args: List[Union[Param, Intermediate]] 107 ret: Intermediate = dataclasses.field(repr=False) 108 109 def __post_init__(self): 110 if self.name == "tt.call": 111 assert self.fn_call_name is not None 112 else: 113 assert self.fn_call_name is None 114 115 116def generate_ttir(kernel, kwargs): 117 """ 118 Uses Triton's internal code generation to create TTIR 119 """ 120 import sympy 121 import triton 122 from triton.compiler.compiler import ASTSource 123 from triton.runtime.autotuner import Autotuner 124 from triton.runtime.jit import JITFunction 125 126 import torch 127 import torch._inductor.ir 128 from torch._subclasses.fake_tensor import FakeTensor 129 130 if isinstance(kernel, Autotuner): 131 if len(kernel.configs) > 0: 132 # If we are autotuning, then it doesn't matter which version gets 133 # picked for tracing purposes, so lets pick the first one 134 kwargs = {**kwargs, **kernel.configs[0].kwargs} 135 kernel = kernel.fn 136 137 assert isinstance(kernel, JITFunction) 138 139 if len(kwargs) != len(kernel.arg_names): 140 raise ValueError("Incorrect number of arguments passed to kernel") 141 142 # Replace all SymExprs with a regular value for TTIR generation 143 # Replace all FakeTensor/TensorBox with real tensors 144 # These replacements are needed for triton's type, key and config functions 145 ordered_args: Dict[str, Any] = {} 146 for name in kernel.arg_names: 147 a = kwargs[name] 148 if isinstance(a, (torch.SymInt, torch.SymFloat, torch.SymBool, sympy.Expr)): 149 ordered_args[name] = 2 150 elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)): 151 with torch._C._DisableTorchDispatch(): 152 ordered_args[name] = torch.empty(2, dtype=a.dtype) 153 else: 154 ordered_args[name] = a 155 156 ordered_tensor_names = [ 157 name for name, arg in ordered_args.items() if isinstance(arg, Tensor) 158 ] 159 specialization = kernel._get_config(*ordered_args.values()) 160 constants = { 161 name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor) 162 } 163 164 # Build kernel signature -- doesn't include constexpr arguments. 165 signature = { 166 name: kernel._type_of(kernel._key_of(arg)) 167 for i, (name, arg) in enumerate(ordered_args.items()) 168 if i not in kernel.constexprs 169 } 170 171 context = triton._C.libtriton.ir.context() 172 target = triton.runtime.driver.active.get_current_target() 173 backend = triton.compiler.compiler.make_backend(target) 174 options = backend.parse_options({}) 175 triton._C.libtriton.ir.load_dialects(context) 176 backend.load_dialects(context) 177 178 src = ASTSource(kernel, signature, constants, specialization) 179 180 # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle 181 # backward compatibility here. 182 make_ir_sig_params = len(inspect.signature(src.make_ir).parameters) 183 if make_ir_sig_params == 2: 184 ttir_module = src.make_ir(options, context) 185 elif make_ir_sig_params == 3: 186 codegen_fns = backend.get_codegen_implementation() 187 ttir_module = src.make_ir(options, codegen_fns, context) 188 else: 189 codegen_fns = backend.get_codegen_implementation() 190 module_map = backend.get_module_map() 191 ttir_module = src.make_ir(options, codegen_fns, module_map, context) 192 if not ttir_module.verify(): 193 raise RuntimeError("Verification for TTIR module has failed") 194 195 return ttir_module, ordered_tensor_names 196 197 198def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]: 199 """ 200 Walk the `ttir_module` bottom up to mine the `functions` from 201 the structured MLIR entities representing the Triton kernel 202 (mlir::Operation, mlir::Block, mlir::Region). 203 """ 204 functions: Dict[str, Dict[Intermediate, List[Op]]] = {} 205 206 # block id --> op result (Intermediate) --> one or more ops 207 op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict( 208 lambda: defaultdict(list) 209 ) 210 region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list) 211 block_id_to_block_arg_ids: Dict[int, List[int]] = {} 212 replacements: Dict[int, Union[Intermediate, Param]] = {} 213 reindex_map: Dict[int, int] = {} 214 next_fake_intermediate = 0 215 216 def reindex(idx): 217 if idx not in reindex_map: 218 reindex_map[idx] = len(reindex_map) 219 return reindex_map[idx] 220 221 def mlir_to_functions(op) -> None: 222 name: str = op.get_name() 223 if name == "builtin.module": 224 # this wraps all tt.func ops 225 return 226 227 operand_ids: List[int] = [ 228 reindex(op.get_operand(i).id()) for i in range(op.get_num_operands()) 229 ] 230 result_ids: List[int] = [ 231 reindex(op.get_result(i).id()) for i in range(op.get_num_results()) 232 ] 233 234 child_block_ids: List[int] = [] 235 for i in [op.get_region(i).id() for i in range(op.get_num_regions())]: 236 # as the walk is bottom-up, the region_id_to_block_ids[i] 237 # must be populated by the time we process the enclosing op 238 child_block_ids.extend(region_id_to_block_ids[i]) 239 240 parent_block_id = -1 241 parent_block = op.get_block() 242 if parent_block is not None: 243 parent_block_id = parent_block.id() 244 if parent_block_id not in block_id_to_block_arg_ids: 245 block_id_to_block_arg_ids[parent_block_id] = [] 246 for i in range(parent_block.get_num_arguments()): 247 block_id_to_block_arg_ids[parent_block_id].append( 248 reindex(parent_block.get_argument(i).id()), 249 ) 250 # the region info is collected via ops' parent blocks to be 251 # used later when the region's encloding op is traversed 252 parent_region = parent_block.get_parent() 253 if parent_region is not None: 254 region_id_to_block_ids[parent_region.id()].append(parent_block_id) 255 256 nonlocal next_fake_intermediate 257 258 if name == "tt.func": 259 # for function ops: gather and inline 260 # the ops from all child blocks 261 fn_ops = defaultdict(list) 262 for child_block_id in child_block_ids: 263 for result, block_fn_ops in op_stack.pop(child_block_id).items(): 264 for block_fn_op in block_fn_ops: 265 fn_ops[result].append(block_fn_op) 266 267 # replace the corresponding Intermediates in the 268 # child op args with the function args (Params) 269 for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]): 270 replacements[idx] = Param(i) 271 272 for fn_op_list in fn_ops.values(): 273 for fn_op in fn_op_list: 274 for i in range(len(fn_op.args)): 275 arg = fn_op.args[i] 276 seen = set() # to break cycles 277 # there can be transitive replacements, but likely 278 # no cycles (we keep the `seen` set just in case) 279 while ( 280 isinstance(arg, Intermediate) 281 and arg.idx in replacements 282 and arg.idx not in seen 283 ): 284 seen.add(arg.idx) 285 arg = fn_op.args[i] = replacements[arg.idx] 286 287 # next function capture starts 288 # with empty replacements 289 replacements.clear() 290 291 fn_name = op.get_str_attr("sym_name") 292 functions[fn_name] = fn_ops 293 elif child_block_ids: 294 if name in {"scf.if", "scf.for", "scf.while", "tt.reduce", "tt.scan"}: 295 # for blocked ops: inline the enclosed ops into 296 # the parent block + rewire the last op in each 297 # child block to return the block result 298 return_ops = [] 299 for block_id in child_block_ids: 300 if name == "scf.for": 301 # example: 302 # %result = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (i32) ... 303 # block args: 2 (%iv, %arg) 304 # op operands: 4 (%lb, %ub, %step, %init) 305 # `%arg` is mapping to `%init` 306 for i, idx in enumerate(block_id_to_block_arg_ids[block_id]): 307 if i == 0: 308 next_fake_intermediate -= 1 309 replacements[idx] = Intermediate(next_fake_intermediate) 310 else: 311 replacements[idx] = Intermediate(operand_ids[i + 2]) 312 elif name == "scf.while": 313 # example: 314 # %3:3 = scf.while (%arg2 = %1, %arg3 = %2, %arg4 = %c0_i32_8) ... 315 # block args: 3 (%arg2, %arg3, %arg4) 316 # op operands: 3 (%1, %2, %c0_i32_8) 317 # `%arg2` is mapping to `%1`, `%arg3` is mapping to `%2`, ... 318 for i, idx in enumerate(block_id_to_block_arg_ids[block_id]): 319 replacements[idx] = Intermediate(operand_ids[i]) 320 elif name == "scf.if": 321 # the scf block args are ignored by the pass. but, as they 322 # may be used as operands of the ops inside the block 323 # (and nested blocks inlined in the current block by now), 324 # they are replaced by new fake Intermediates to avoid "this 325 # operand is not returned by any other op in the fn" error 326 # in the downstream analysis 327 for idx in block_id_to_block_arg_ids[block_id]: 328 next_fake_intermediate -= 1 329 replacements[idx] = Intermediate(next_fake_intermediate) 330 else: 331 assert name in ("tt.reduce", "tt.scan") 332 # wire the block arguments to the op arguments 333 num_operands = len(operand_ids) 334 block_arg_ids = block_id_to_block_arg_ids[block_id] 335 assert len(block_arg_ids) == 2 * num_operands, ( 336 f"{name} is expected to have twice as " 337 "many block arguments as op arguments: " 338 f"{operand_ids=}, {block_arg_ids=}." 339 ) 340 for i, idx in enumerate(block_arg_ids): 341 # for a tt.reduce/tt.scan op with N arguments, the block 342 # arguments comprise N reduced values followed by 343 # N current values corresponding to the N op args 344 replacements[idx] = Intermediate( 345 operand_ids[i % num_operands] 346 ) 347 348 if block_id in op_stack: 349 block_ops = op_stack.pop(block_id) 350 if not block_ops: 351 continue 352 last_ret, last_ops = block_ops.popitem() 353 if all( 354 op.name 355 in ("scf.yield", "tt.reduce.return", "tt.scan.return") 356 for op in last_ops 357 ): 358 # if last_ops are all return ops, treat them separately 359 return_ops.extend(last_ops) 360 else: 361 # otherwise, return last_ops to the block 362 block_ops[last_ret] = last_ops 363 for op_result, child_ops in block_ops.items(): 364 op_stack[parent_block_id][op_result].extend(child_ops) 365 366 scf_results = [Intermediate(idx) for idx in result_ids] 367 for scf_result in scf_results: 368 for return_op in return_ops: 369 op_stack[parent_block_id][scf_result].append(return_op) 370 else: 371 raise RuntimeError( 372 f"Unknown blocked function: {name}. Can't capture the TTIR." 373 ) 374 else: 375 callee = None 376 if name == "tt.call": 377 callee = op.get_flat_symbol_ref_attr("callee") 378 args: List[Union[Param, Intermediate]] = [ 379 Intermediate(operand) for operand in operand_ids 380 ] 381 block_ops = op_stack[parent_block_id] 382 if result_ids: 383 for result_id in result_ids: 384 res = Intermediate(result_id) 385 block_ops[res].append(Op(name, callee, args, res)) 386 else: 387 next_fake_intermediate -= 1 388 fake_res = Intermediate(next_fake_intermediate) 389 block_ops[fake_res].append(Op(name, callee, args, fake_res)) 390 391 ttir_module.walk(mlir_to_functions) 392 393 return functions 394 395 396class MemoizeWithCycleCheck: 397 def __init__(self, fn): 398 self.fn = fn 399 self.reset() 400 401 def __call__(self, functions, fn_name, num_args): 402 key = (fn_name, num_args) 403 if key not in self.cache: 404 self.cache[key] = None 405 self.cache[key] = self.fn(functions, fn_name, num_args) 406 if self.cache[key] is None: 407 raise RuntimeError("Recursion is not supported") 408 return self.cache[key] 409 410 def reset(self): 411 self.cache = {} 412 413 414@MemoizeWithCycleCheck 415def analyze_kernel_mutations(functions, fn_name, num_args): 416 """ 417 Analyzes the graph to detect all sinks from a predefined list of sinks 418 by using triton's MemWrite trait list. NOTE: What if triton exposed this? 419 From each sink, it traverses the CFG backwards to identify all the input 420 pointers that are mutated. 421 """ 422 # Name of mutation op to mutated parameter indices 423 # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td 424 # All the OPs that have MemWrite trait. 425 # What if Triton exposed this? 426 MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]} 427 # Ops that we want to bail out on 428 UNKNOWN_OPS = {"tt.elementwise_inline_asm"} 429 430 stack: List[Union[Param, Intermediate]] = [] 431 visited = set() 432 ops = functions[fn_name] 433 for op_list in ops.values(): 434 for op in op_list: 435 if op.name in UNKNOWN_OPS: 436 raise RuntimeError( 437 f"ttir analysis hit an op we do not know how to analyze: {op.name}" 438 ) 439 440 if op.name == "tt.call": 441 assert op.fn_call_name in functions 442 mutations = analyze_kernel_mutations( 443 functions, op.fn_call_name, len(op.args) 444 ) 445 stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated) 446 else: 447 for idx in MUTATION_OPS.get(op.name, []): 448 stack.append(op.args[idx]) 449 450 # The following is an iterative DFS algorithm 451 mutated = [False] * num_args 452 while stack: 453 arg = stack.pop() 454 if arg in visited: 455 continue 456 457 visited.add(arg) 458 459 if isinstance(arg, Param): 460 if arg.idx >= num_args: 461 # This is an argument defined in the kernel, not passed in 462 continue 463 mutated[arg.idx] = True 464 elif isinstance(arg, Intermediate) and not arg.fake(): 465 for op in ops[arg]: 466 # Skip arguments to load 467 if op.name != "tt.load": 468 stack.extend(op.args) 469 return mutated 470 471 472def identify_mutated_tensors(kernel, kwargs): 473 """ 474 Given a triton kernel and the arguments for this kernel, this function 475 1) Retrieves the TTIR converted version of the kernel from Triton's API. 476 2) Parses the TTIR and creates a control flow graph 477 3) Analyzes the graph to detect all input tensor mutations 478 """ 479 480 ttir_module = None 481 functions = None 482 try: 483 ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs) 484 485 # extract functions from TTIR using MLIR bindings exposed by Triton code 486 functions = ttir_to_functions(ttir_module) 487 488 assert functions is not None 489 kernel_name = next(iter(functions.keys())) 490 # Triton codegen modifies the name 491 assert kernel.fn.__name__ in kernel_name 492 # Reset the cache between top level invocations 493 # The cache for analyze kernel mutations is mainly used for cycle 494 # detection, so each top level invocation needs a clean cache 495 analyze_kernel_mutations.reset() 496 mutations = analyze_kernel_mutations( 497 functions, kernel_name, len(ordered_tensor_names) 498 ) 499 500 return [ 501 ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated 502 ] 503 except Exception as e: 504 log.warning( 505 "Encountered an exception in identify_mutated_tensors, assuming every input is mutated", 506 exc_info=True, 507 ) 508 if ttir_module is not None: 509 log.debug("TTIR:\n%s", str(ttir_module)) 510 if functions is not None: 511 log.debug("functions:") 512 for name, fn in functions.items(): 513 log.debug("===\t%s\t===", name) 514 for ret, ops in fn.items(): 515 log.debug("%s\t=>\t%s", ret, ops) 516 return [key for key, value in kwargs.items() if isinstance(value, Tensor)] 517 518 519############################################################################### 520# Triton Kernel Wrappers 521 522 523# Used for wrapping a Triton Kernel 524class TritonKernelWrapperMutation(HigherOrderOperator): 525 def __init__(self) -> None: 526 super().__init__("triton_kernel_wrapper_mutation") 527 528 def __call__(self, kernel_idx, constant_args_idx, grid, kwargs): 529 return super().__call__( 530 kernel_idx=kernel_idx, 531 constant_args_idx=constant_args_idx, 532 grid=grid, 533 kwargs=kwargs, 534 ) 535 536 537triton_kernel_wrapper_mutation = TritonKernelWrapperMutation() 538 539 540# Used for wrapping a Triton Kernel in a functional manner 541class TritonKernelWrapperFunctional(HigherOrderOperator): 542 def __init__(self) -> None: 543 super().__init__("triton_kernel_wrapper_functional") 544 545 def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone): 546 return super().__call__( 547 kernel_idx=kernel_idx, 548 constant_args_idx=constant_args_idx, 549 grid=grid, 550 kwargs=kwargs, 551 tensors_to_clone=tensors_to_clone, 552 ) 553 554 555triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() 556 557 558@triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd) 559def triton_kernel_wrapper_mutation_dense( 560 *, kernel_idx, constant_args_idx, grid, kwargs 561): 562 from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code 563 564 kernel = kernel_side_table.get_kernel(kernel_idx) 565 constant_args = kernel_side_table.get_constant_args(constant_args_idx) 566 567 if len(grid) == 1: 568 grid_fn = grid[0] 569 else: 570 fn_name, code = user_defined_kernel_grid_fn_code( 571 kernel.fn.__name__, kernel.configs, grid 572 ) 573 namespace: Dict[str, Any] = {} 574 exec(code, namespace) 575 grid_fn = namespace[fn_name] 576 577 kernel[grid_fn](**kwargs, **constant_args) 578 579 580@triton_kernel_wrapper_mutation.py_impl(FakeTensorMode) 581def triton_kernel_wrapper_mutation_fake_tensor_mode( 582 mode, *, kernel_idx, constant_args_idx, grid, kwargs 583): 584 with mode: 585 return None 586 587 588@triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta) 589def _(*, kernel_idx, constant_args_idx, grid, kwargs): 590 return None 591 592 593def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args): 594 with disable_proxy_modes_tracing(): 595 out = func_overload(**node_args) 596 597 proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) 598 out_proxy = proxy_mode.tracer.create_proxy( 599 "call_function", 600 func_overload, 601 (), 602 proxy_args, 603 name=func_overload.__name__ + "_proxy", 604 ) 605 ret = track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) 606 return ret 607 608 609@triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode) 610def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( 611 mode, *, kernel_idx, constant_args_idx, grid, kwargs 612): 613 trace_triton_kernel_wrapper( 614 mode, 615 triton_kernel_wrapper_mutation, 616 { 617 "kernel_idx": kernel_idx, 618 "constant_args_idx": constant_args_idx, 619 "grid": grid, 620 "kwargs": kwargs, 621 }, 622 ) 623 624 return None 625 626 627@triton_kernel_wrapper_mutation.py_functionalize_impl 628def triton_kernel_wrapper_mutation_functionalize( 629 ctx, kernel_idx, constant_args_idx, grid, kwargs 630): 631 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) 632 kernel = kernel_side_table.get_kernel(kernel_idx) 633 constant_args = kernel_side_table.get_constant_args(constant_args_idx) 634 # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each 635 # other, and one gets mutated in kernel, and later another gets mutated, 636 # they are no longer equal. Fix this by graph breaking on this condition 637 # earlier in dynamo. 638 tensors_to_clone = identify_mutated_tensors( 639 kernel, {**unwrapped_kwargs, **constant_args} 640 ) 641 with ctx.redispatch_to_next(): 642 unwrapped_outputs = triton_kernel_wrapper_functional( 643 kernel_idx=kernel_idx, 644 constant_args_idx=constant_args_idx, 645 grid=grid, 646 kwargs=unwrapped_kwargs, 647 tensors_to_clone=tensors_to_clone, 648 ) 649 650 assert set(unwrapped_outputs.keys()).issubset(set(kwargs.keys())) 651 for key, output_arg in unwrapped_outputs.items(): 652 if not isinstance(output_arg, Tensor): 653 continue 654 input_arg = kwargs[key] 655 assert isinstance(input_arg, Tensor) 656 657 ctx.replace(input_arg, output_arg) 658 # indicate that above replace is hidden from autograd 659 ctx.mark_mutation_hidden_from_autograd(input_arg) 660 ctx.commit_update(input_arg) 661 ctx.sync(input_arg) 662 return None 663 664 665@triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd) 666def triton_kernel_wrapper_functional_dense( 667 *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone 668): 669 # TODO(oulgen): For performance reasons, we want to ensure that these 670 # `clone_preserve_strides` calls are never executed at runtime 671 # (inductor should always optimize them away). 672 # Requires https://github.com/pytorch/pytorch/issues/109240 673 kwargs = { 674 key: (clone_preserve_strides(val) if key in tensors_to_clone else val) 675 for key, val in kwargs.items() 676 } 677 triton_kernel_wrapper_mutation( 678 kernel_idx=kernel_idx, 679 constant_args_idx=constant_args_idx, 680 grid=grid, 681 kwargs=kwargs, 682 ) 683 return {key: val for key, val in kwargs.items() if key in tensors_to_clone} 684 685 686@triton_kernel_wrapper_functional.py_impl(FakeTensorMode) 687def triton_kernel_wrapper_functional_fake_tensor_mode( 688 mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone 689): 690 # TODO(oulgen): For performance reasons, we want to ensure that these 691 # `clone_preserve_strides` calls are never executed at runtime 692 # (inductor should always optimize them away). 693 # Requires https://github.com/pytorch/pytorch/issues/109240 694 with mode: 695 return { 696 key: clone_preserve_strides(val) 697 for key, val in kwargs.items() 698 if key in tensors_to_clone 699 } 700 701 702@triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode) 703def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode( 704 mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone 705): 706 return trace_triton_kernel_wrapper( 707 mode, 708 triton_kernel_wrapper_functional, 709 { 710 "kernel_idx": kernel_idx, 711 "constant_args_idx": constant_args_idx, 712 "grid": grid, 713 "kwargs": kwargs, 714 "tensors_to_clone": tensors_to_clone, 715 }, 716 ) 717 718 719@triton_kernel_wrapper_functional.py_functionalize_impl 720def triton_kernel_wrapper_functional_functionalize( 721 ctx, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone 722): 723 unwrapped_kwargs = ctx.unwrap_tensors(kwargs) 724 with ctx.redispatch_to_next(): 725 outputs = triton_kernel_wrapper_functional( 726 kernel_idx=kernel_idx, 727 constant_args_idx=constant_args_idx, 728 grid=grid, 729 kwargs=unwrapped_kwargs, 730 tensors_to_clone=tensors_to_clone, 731 ) 732 return ctx.wrap_tensors(outputs) 733 734 735triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] 736triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] 737triton_kernel_wrapper_mutation.fallthrough(DispatchKey.ADInplaceOrView) 738triton_kernel_wrapper_mutation.fallthrough(DispatchKey.BackendSelect) 739triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] 740triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] 741triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCUDA) 742triton_kernel_wrapper_mutation.fallthrough(DispatchKey.AutogradCPU) 743 744triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] 745triton_kernel_wrapper_functional.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined] 746triton_kernel_wrapper_functional.fallthrough(DispatchKey.ADInplaceOrView) 747triton_kernel_wrapper_functional.fallthrough(DispatchKey.BackendSelect) 748triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCPU) # type: ignore[attr-defined] 749triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutocastCUDA) # type: ignore[attr-defined] 750triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) 751triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCUDA) 752triton_kernel_wrapper_functional.fallthrough(DispatchKey.AutogradCPU) 753 754 755############################################################################### 756# The "TritonHOPifier": a class that transforms a call to a triton kernel into 757# a call to the triton_kernel_wrapper_mutation HOP. 758 759 760class TritonHOPifier: 761 """Orchestrator for converting a user-defined triton kernel into a call 762 to the triton_kernel_wrapper_mutation HOP. 763 764 It has two main use cases. 765 766 1. When Dynamo sees a triton kernel, it wraps it into a TritonKernelVariable 767 and uses the TritonHOPifier to convert calls to the TritonKernelVariable 768 into a call to the HOP. 769 770 2. In order to capture a user-defined triton kernel while performing 771 tracing (via make_fx or non-strict export), a user must annotate their 772 triton kernel with the `capture_triton` decorator. The decorator uses 773 TritonHOPifier to convert calls to the triton kernel into a call 774 to the HOP (which can then be traced). 775 776 Because Dynamo has its own calling conventions for e.g. invoking a user-defined function 777 TritonHOPifier is an abstract class that can be overriden by its subclasses. 778 """ 779 780 def raise_unsupported(self, msg): 781 raise NotImplementedError("abstract method") 782 783 def is_callable(self, maybe_callable): 784 raise NotImplementedError("abstract method") 785 786 def get_value(self, val): 787 raise NotImplementedError("abstract method") 788 789 def call_grid(self, grid, meta, tx): 790 raise NotImplementedError("abstract method") 791 792 def call_HOP(self, variable, grids, combined_args, tx): 793 raise NotImplementedError("abstract method") 794 795 def check_grid(self, grid): 796 raise NotImplementedError("abstract method") 797 798 def init_variable(self, variable, kernel, kernel_idx, grid): 799 from triton.runtime.autotuner import Autotuner 800 801 assert kernel is not None 802 803 variable.kernel = kernel 804 variable.kernel_idx = kernel_side_table.add_kernel(kernel) 805 806 assert kernel_idx is None or variable.kernel_idx == kernel_idx 807 808 variable.grid = grid 809 810 if isinstance(kernel, Autotuner): 811 import torch 812 import torch._dynamo 813 814 # We only support configs and keys arguments of triton.autotune 815 # Make sure other arguments are defaulted 816 defaults = inspect.signature(Autotuner.__init__).parameters 817 818 # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep. 819 # The call to get_first_attr is to maintain backward-compatibility. 820 if ( 821 not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args 822 and ( 823 ( 824 "warmup" in defaults 825 and defaults["warmup"].default 826 != torch._dynamo.utils.get_first_attr( 827 kernel, "num_warmups", "warmup" 828 ) 829 ) 830 or ( 831 "rep" in defaults 832 and defaults["rep"].default 833 != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep") 834 ) 835 or ( 836 "prune_configs_by" in defaults 837 and defaults["prune_configs_by"].default 838 != kernel.early_config_prune 839 ) 840 # Set via reset_to_zero argument 841 or len(kernel.reset_idx) != 0 842 or len(kernel.restore_idx) != 0 843 or ( 844 "use_cuda_graph" in defaults 845 and defaults["use_cuda_graph"].default != kernel.use_cuda_graph 846 ) 847 ) 848 ): 849 self.raise_unsupported( 850 "Only configs and keys are supported for triton.autotune" 851 ) 852 853 def call_getitem(self, variable, args): 854 # __getitem__ should only be called if we don't already have a grid 855 # Only grid needs to be passed 856 if variable.grid is not None or len(args) != 1: 857 self.raise_unsupported( 858 "Triton kernels should be called with only a single grid" 859 ) 860 861 return type(variable)( 862 kernel=variable.kernel, 863 kernel_idx=variable.kernel_idx, 864 grid=args[0], 865 ) 866 867 def call_run(self, variable, args, kwargs, tx): 868 if "grid" not in kwargs: 869 self.raise_unsupported("Triton kernel requires to be called with a grid") 870 grid = kwargs.pop("grid") 871 kwargs.pop("warmup", None) 872 # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args) 873 return self.call_triton_kernel( 874 type(variable)( 875 kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid 876 ), 877 args, 878 kwargs, 879 tx, 880 ) 881 882 def call_triton_kernel(self, variable, args, kwargs, tx): 883 from triton.runtime.autotuner import autotune, Autotuner, Config 884 885 if "num_ctas" in kwargs: 886 self.raise_unsupported( 887 "Passing num_ctas directly to the Triton kernel is not supported. " 888 "Please use a Config in @triton.autotune instead." 889 ) 890 891 special_kwargs = {} 892 for name in ("num_warps", "num_stages"): 893 if name in kwargs: 894 # remove special kwargs from `kwargs` 895 val = kwargs.pop(name) 896 special_kwargs[name] = self.get_value(val) 897 898 if special_kwargs: 899 if isinstance(variable.kernel, Autotuner): 900 # if there is Autotuner already, set 901 # special kwargs to each of its configs 902 new_configs = copy.deepcopy(variable.kernel.configs) 903 for config in new_configs: 904 config.__dict__.update(special_kwargs) 905 new_kernel = autotune(configs=new_configs, key=[])(variable.kernel.fn) 906 else: 907 # if there is no Autotuner, wrap the kernel into a 908 # new one with a single config with special kwargs 909 new_config = Config(kwargs={}, **special_kwargs) 910 new_kernel = autotune(configs=[new_config], key=[])(variable.kernel) 911 912 # create a new variable to contain the new (wrapped) kernel; 913 # skip kernel_idx to get a new record in the kernel side table 914 new_var = type(variable)(new_kernel, None, variable.grid) 915 return self.call_triton_kernel(new_var, args, kwargs, tx) 916 917 if variable.grid is None: 918 self.raise_unsupported("Triton kernels should always be called with a grid") 919 920 # Both for grid's meta as well as for the kernel, we need combined 921 # args and kwargs combined and normalized 922 combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs} 923 924 configs = ( 925 [config.kwargs for config in variable.kernel.configs] 926 if isinstance(variable.kernel, Autotuner) 927 else [{}] 928 ) 929 grids = [] 930 for config_args in configs: 931 # If the grid is a function, then lets execute it and convert it to 932 # a list 933 grid = variable.grid 934 if self.is_callable(grid): 935 # Populate the special "meta" argument to call the grid function 936 meta = {**combined_args_raw, **config_args} 937 grid = self.call_grid(grid, meta, tx) 938 grids.append(self.check_grid(grid)) 939 940 for i in range(len(grids)): 941 if not isinstance(grids[i], tuple): 942 self.raise_unsupported("Only tuple grids are supported") 943 # inductor expects all grids to be 3-tuple so lets make it 944 if len(grids[i]) == 1: 945 grids[i] = (grids[i][0], 1, 1) 946 elif len(grids[i]) == 2: 947 grids[i] = (grids[i][0], grids[i][1], 1) 948 elif len(grids[i]) > 3: 949 self.raise_unsupported("Grid can have at most rank 3") 950 951 assert len(grids) != 0 952 953 def intify(x): 954 if isinstance(x, torch.SymInt): 955 return int(x) 956 else: 957 return x 958 959 if len(set(pytree.tree_map(intify, grids))) == 1: 960 # If there's only one unique grid, lets simplify 961 grids = [grids[0]] 962 963 return self.call_HOP(variable, grids, combined_args_raw, tx) 964 965 966############################################################################### 967# Helpers for capture_triton API that makes a user-defined triton kernel traceable into 968# a graph via make_fx or non-strict export (coming soon) 969 970 971class TracingTritonHOPifier(TritonHOPifier): 972 def raise_unsupported(self, msg): 973 raise RuntimeError(msg) 974 975 def is_callable(self, maybe_callable): 976 return callable(maybe_callable) 977 978 def get_value(self, val): 979 return val 980 981 def call_grid(self, grid, meta, tx): 982 assert tx is None 983 return grid(meta) 984 985 def check_grid(self, grid): 986 if not isinstance(grid, collections.abc.Sequence): 987 raise RuntimeError( 988 "capture_triton can only handle grids that resolve to Sequence[int]." 989 ) 990 # normalize to tuple 991 return tuple(grid) 992 993 def call_HOP(self, variable, grids, combined_args, tx): 994 assert tx is None 995 996 def is_graphable(val): 997 return isinstance(val, fx.node.base_types) 998 999 non_graphable_args = { 1000 k: v for k, v in combined_args.items() if not is_graphable(v) 1001 } 1002 graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)} 1003 1004 constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args) 1005 return triton_kernel_wrapper_mutation( 1006 kernel_idx=variable.kernel_idx, 1007 constant_args_idx=constant_args_idx, 1008 grid=grids, 1009 kwargs=graphable_args, 1010 ) 1011 1012 1013tracing_triton_hopifier_singleton = TracingTritonHOPifier() 1014 1015 1016class TraceableTritonKernelWrapper: 1017 def __init__(self, kernel, kernel_idx, grid): 1018 self.kernel = None 1019 self.grid = None 1020 tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) 1021 assert self.kernel is not None 1022 1023 def __getitem__(self, *args): 1024 return tracing_triton_hopifier_singleton.call_getitem(self, args) 1025 1026 def run(self, *args, **kwargs): 1027 from torch._library.triton import is_capture_triton_enabled 1028 1029 if is_capture_triton_enabled(): 1030 return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None) 1031 else: 1032 assert self.kernel is not None 1033 return self.kernel.run(*args, **kwargs) 1034 1035 def __call__(self, *args, **kwargs): 1036 from torch._library.triton import is_capture_triton_enabled 1037 1038 if is_capture_triton_enabled(): 1039 return tracing_triton_hopifier_singleton.call_triton_kernel( 1040 self, args, kwargs, None 1041 ) 1042 else: 1043 assert self.kernel is not None 1044 return self.kernel[self.grid](*args, **kwargs) 1045