# mypy: allow-untyped-defs
from __future__ import annotations

import builtins
import copy
import functools
import hashlib
import inspect
import logging
import math
import operator
import os
import os.path
import re
import sys
import threading
import time
from typing import Any, Dict, List, Optional, Set, Tuple

import torch

from .autotune_cache import AutotuneCache
from .benchmarking import benchmarker
from .coordinate_descent_tuner import CoordescTuner
from .hints import (
    _NUM_THREADS_PER_WARP,
    AutotuneHint,
    DeviceProperties,
    HeuristicType,
    ReductionHint,
    TileHint,
    TRITON_MAX_BLOCK,
)
from .runtime_utils import (
    cache_dir,
    ceildiv,
    conditional_product,
    create_bandwidth_info_str,
    dynamo_timed,
    get_first_attr,
    get_max_y_grid,
    get_num_bytes,
    next_power_of_2,
    triton_config_to_hashable,
    validate_triton_config,
)


try:
    import triton
except ImportError:
    triton = None

if triton is not None:
    from triton import Config
    from triton.compiler import CompiledKernel
    from triton.runtime.autotuner import OutOfResources
    from triton.runtime.jit import KernelInterface

    try:
        from triton.compiler.compiler import ASTSource
    except ImportError:
        ASTSource = None

    try:
        from triton.backends.compiler import GPUTarget
    except ImportError:
        GPUTarget = None
else:
    Config = object
    KernelInterface = object
    OutOfResources = object
    ASTSource = None
    GPUTarget = None

try:
    autograd_profiler = torch.autograd.profiler
except AttributeError:  # Compile workers only have a mock version of torch

    class autograd_profiler:  # type: ignore[no-redef]
        _is_profiler_enabled = False


log = logging.getLogger(__name__)


def autotune_hints_to_configs(
    hints: Set[AutotuneHint], size_hints, block_size: int
) -> List[Config]:
    """
    AutotuneHints can be attached to the metadata of triton kernels for providing
    suggestions about what to try for autotuning. One reason to do this is if there are
    some configs that are only useful in specific scenarios, in which case we can avoid
    wasting compile time on autotuning unless we know we are in one of those scenarios.

    Based on those hints, this function will generate a list of additional autotuning
    configs to try.
    """
    xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
    configs = []

    for hint in hints:
        if hint == AutotuneHint.ELEMENTS_PER_WARP_32:
            if len(size_hints) == 1:
                xyz_options = ((block_size // 4, None, None),)
            elif len(size_hints) == 2:
                xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None))
            elif len(size_hints) == 3:
                xyz_options = (
                    (block_size // 4, 1, 1),
                    (1, block_size // 4, 1),
                    (1, 1, block_size // 4),
                )
            for xyz in xyz_options:
                configs.append(
                    triton_config(
                        size_hints,
                        *xyz,
                        num_elements_per_warp=32,
                    )
                )

    return configs
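

# Illustrative example (assumed sizes, not exercised anywhere in this module):
# with size_hints=[4096], block_size=1024 and the ELEMENTS_PER_WARP_32 hint set,
# autotune_hints_to_configs returns one extra config built via
# triton_config([4096], 256, num_elements_per_warp=32), i.e. XBLOCK=256 targeting
# 32 elements per warp.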


def disable_pointwise_autotuning(inductor_meta):
    # Autotuning can give different benchmarking results from run to run, and
    # therefore we disable autotuning when the deterministic-algorithms flag is on.
    if inductor_meta.get("are_deterministic_algorithms_enabled"):
        return True
    return not inductor_meta.get("autotune_pointwise", True)


def _dump_launch_params(args, kwargs, launcher, kernel_name):
    call_args = []
    call_kwargs = {}
    for arg in args:
        if isinstance(arg, (int, bool)):
            call_args.append(str(arg))
        else:
            call_args.append("T")
    for k, v in kwargs.items():
        # mirror the positional-arg handling: keep scalar values, redact tensors as "T"
        if isinstance(v, (int, bool)):
            call_kwargs[k] = v
        else:
            call_kwargs[k] = "T"
    for k, v in launcher.config.kwargs.items():
        call_kwargs[k] = v
    call_kwargs["num_warps"] = launcher.config.num_warps
    call_kwargs["num_stages"] = launcher.config.num_stages
    args_str = ""
    args_str += ", ".join(call_args)
    for k, v in call_kwargs.items():
        args_str += f", {k}={v}"

    abs_path = os.path.abspath(sys.argv[0])
    with open(f"{abs_path}.launch_params", "a") as f:
        f.write(f"{kernel_name} | {args_str}\n")


class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
    key and caches the best config to disk to improve cold start times.
    Unlike the main triton Autotuner, this version can precompile all
    configs, and does not rely on the Triton JIT.
    """

    def __init__(
        self,
        fn,
        triton_meta,  # passed directly to triton
        configs,
        save_cache_hook,
        mutated_arg_names: List[str],  # see [Note: clone mutated buffers]
        heuristic_type,
        size_hints=None,
        inductor_meta=None,  # metadata not relevant to triton
        custom_kernel=False,  # whether the kernel is inductor-generated or custom
        filename: Optional[str] = None,
    ):
        super().__init__()

        assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
        # makes sure there are no pre-hooks on any of the triton configs
        for cfg in configs:
            validate_triton_config(cfg)

        self.fn = fn
        self.device_props: DeviceProperties = triton_meta["device"]
        self.triton_meta = {
            **triton_meta,
            "device": self.device_props.index,
            "device_type": self.device_props.type,
        }
        self.inductor_meta = {} if inductor_meta is None else inductor_meta
        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.configs = configs
        self.heuristic_type = heuristic_type
        self.custom_kernel = custom_kernel
        self.cuda_kernel_saved = False
        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                "CachingAutotuner gets %d configs for %s",
                len(self.configs),
                self.fn.__name__,
            )
            for c in self.configs:
                log.debug(c)

        self.launchers = []  # type: ignore[var-annotated]
        self.lock = threading.Lock()
        if os.getenv("TRITON_CACHE_DIR") is None:
            os.environ["TRITON_CACHE_DIR"] = os.path.join(
                cache_dir(),
                "triton",
                str(self.triton_meta.get("device", 0)),
            )
        log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])

        self.size_hints = size_hints
        self.coordesc_tuner = CoordescTuner(
            is_mm=False,
            name=self.fn.__name__,
            size_hints=size_hints,
            inductor_meta=self.inductor_meta,
        )
        self.filename = filename

        self.precompile_time_taken_ns = 0
        self.autotune_time_taken_ns = 0

    def precompile(self, warm_cache_only=False):
        with self.lock:
            if self.launchers:
                return
            self.launchers = []
            compiled_binaries = []
            if not self.configs:
                raise RuntimeError("No triton configs are available")
            for c in self.configs:
                try:
                    compiled_binary, launcher = self._precompile_config(
                        c, warm_cache_only
                    )
                except OutOfResources as e:
                    if len(self.configs) == 1:
                        # There are no valid Triton configs
                        raise e
                    # Skip the config if we run out of resources
                    continue
                self.launchers.append(launcher)
                compiled_binaries.append(compiled_binary)

            if len(self.launchers) == 0:
                raise RuntimeError(
                    "No valid triton configs. Report a fatal compilation error"
                )

            seen_configs = set(self.configs)

            device_prop = self.device_props
            if (
                self.inductor_meta.get("dynamic_scale_rblock", True)
                and self.heuristic_type == HeuristicType.REDUCTION
                and self.size_hints is not None
                # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary.
                and device_prop.type == "cuda"
                and device_prop.major
                and device_prop.major >= 8
            ):
                assert device_prop.regs_per_multiprocessor
                assert device_prop.max_threads_per_multi_processor
                assert device_prop.multi_processor_count
                for triton_config, compiled_binary in zip(
                    self.configs, compiled_binaries
                ):
                    assert len(self.size_hints) == 2
                    xblock = triton_config.kwargs.get("XBLOCK", 1)
                    rblock = triton_config.kwargs["RBLOCK"]
                    total_block = (self.size_hints[0] + xblock - 1) // xblock
                    nreg = getattr(compiled_binary, "n_regs", None)
                    if nreg is None:
                        continue

                    # make sure rblock is not too small
                    if rblock <= 64:
                        continue

                    # each SM of A100 has 65536 32-bit registers. To maximize
                    # the theoretical occupancy, we need to run 2048 threads on
                    # each SM, so each thread should use no more than
                    # 65536 / 2048 = 32 registers. In cases where occupancy
                    # matters and each thread uses too many registers, reduce
                    # RBLOCK to reduce the register usage.
                    # For the kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
                    # from PLBartForCausalLM, latency improves from
                    # 7.795ms to 4.883ms.
                    #
                    if (
                        nreg
                        <= device_prop.regs_per_multiprocessor
                        // device_prop.max_threads_per_multi_processor
                    ):
                        continue

                    nreg_per_warp = nreg * 32
                    nreg_per_block = nreg_per_warp * triton_config.num_warps

                    # Previously we set max_blocks_per_sm to
                    # 'max_threads_per_multi_processor / (32 * num_warps)'.
                    # The formula below is a tighter upper bound since, due to
                    # the if condition above, we have
                    #   nreg > regs_per_multiprocessor // max_threads_per_multi_processor
                    # and therefore:
                    #   regs_per_multiprocessor / nreg_per_block
                    #   = regs_per_multiprocessor / (nreg * 32 * num_warps)
                    #   < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
                    #   = max_threads_per_multi_processor / (32 * num_warps)
                    # Using a tighter upper bound can reveal more optimization opportunities.
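                    # Worked example with illustrative numbers (assumed, not measured
                    # here): on an A100-like SM with regs_per_multiprocessor == 65536
                    # and max_threads_per_multi_processor == 2048, a config with
                    # num_warps == 8 whose kernel uses nreg == 40 registers gives
                    # nreg_per_block = 40 * 32 * 8 = 10240, so
                    # max_blocks_per_sm = 65536 // 10240 = 6, versus the looser
                    # 2048 // (32 * 8) = 8 from the old thread-count bound.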
                    max_blocks_per_sm = max(
                        device_prop.regs_per_multiprocessor // nreg_per_block, 1
                    )

                    if (
                        total_block
                        <= max_blocks_per_sm * device_prop.multi_processor_count
                    ):
                        # no need to improve occupancy
                        continue
                    new_config = copy.deepcopy(triton_config)
                    new_config.kwargs["RBLOCK"] = rblock // 2
                    if new_config in seen_configs:
                        continue
                    seen_configs.add(new_config)
                    log.debug(
                        "Dynamically scale down RBLOCK from TritonConfig(%s) and get a new TritonConfig(%s)",
                        triton_config,
                        new_config,
                    )
                    self.launchers.append(
                        self._precompile_config(new_config, warm_cache_only)[1]
                    )
            self.configs = None

    def get_device_interface(self):
        # this code cannot run in compile workers, because it imports from torch
        from torch._dynamo.device_interface import get_interface_for_device

        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))

    def _precompile_config(self, cfg: Config, warm_cache_only: bool):
        """Ahead of time compile a given autotuner config."""
        compile_meta = copy.deepcopy(self.triton_meta)
        for k, v in cfg.kwargs.items():
            if self.device_props.type == "hip":
                if k == "matrix_instr_nonkdim":
                    compile_meta["matrix_instr_nonkdim"] = v
                    continue
                if k == "waves_per_eu":
                    compile_meta["waves_per_eu"] = v
                    continue
            compile_meta["constants"][self.fn.arg_names.index(k)] = v
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages
        compile_meta["debug"] = self.inductor_meta.get(
            "assert_indirect_indexing", True
        ) and not self.inductor_meta.get("is_hip", False)

        # device type will be "hip" rather than "cuda" here
        compile_meta["device_type"] = self.device_props.type
        compile_meta["cc"] = self.device_props.cc

        if ASTSource:
            compile_args = (
                ASTSource(
                    self.fn,
                    compile_meta["signature"],
                    compile_meta["constants"],
                    compile_meta["configs"][0],
                ),
            )

            cc_str = str(compile_meta["cc"])
            if "gfx10" in cc_str or "gfx11" in cc_str:
                rocm_warp_size = 32
            else:
                rocm_warp_size = 64

            if GPUTarget:
                target = GPUTarget(
                    compile_meta["device_type"],
                    compile_meta["cc"],
                    rocm_warp_size if torch.version.hip else 32,
                )
            else:
                target = (
                    (compile_meta["device_type"], compile_meta["cc"])
                    if not torch.version.hip
                    else [
                        compile_meta["device_type"],
                        compile_meta["cc"],
                        rocm_warp_size,
                    ]
                )

            options = {
                "num_warps": compile_meta["num_warps"],
                "num_stages": compile_meta["num_stages"],
                "debug": compile_meta["debug"],
            }
            if self.device_props.type == "hip":
                if "waves_per_eu" in compile_meta:
                    options["waves_per_eu"] = compile_meta["waves_per_eu"]
                if "matrix_instr_nonkdim" in compile_meta:
                    options["matrix_instr_nonkdim"] = compile_meta[
                        "matrix_instr_nonkdim"
                    ]
            compile_kwargs = {
                "target": target,
                "options": options,
            }
        else:
            compile_args = (self.fn,)
            compile_kwargs = compile_meta

        if warm_cache_only:
            return (
                triton.compile(*compile_args, **compile_kwargs),
                None,
            )

        # importing from torch is safe now that precompile has returned
        from torch._dynamo.device_interface import DeviceGuard

        device_interface = self.get_device_interface()

        # load binary to the correct device
        with DeviceGuard(device_interface, compile_meta["device"]):  # type: ignore[attr-defined]
            # need to initialize context
            device_interface.synchronize(device_interface.current_device())

            try:
                binary = triton.compile(*compile_args, **compile_kwargs)
            except Exception:
                log.exception(
                    "Triton compilation failed: %s\n%s\nmetadata: %s",
                    self.inductor_meta.get("kernel_name", "triton_"),
                    self.fn.src,
                    compile_meta,
                )
                raise
            binary._init_handles()

        call_args = [
            arg
            for i, arg in enumerate(self.fn.arg_names)
            if i not in self.fn.constexprs
        ]
        def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs]

        binary_shared = (
            binary.shared if hasattr(binary, "shared") else binary.metadata.shared
        )

        scope = {
            "grid_meta": cfg.kwargs,
            "bin": binary,
            "launch_enter_hook": CompiledKernel.launch_enter_hook,
            "launch_exit_hook": CompiledKernel.launch_exit_hook,
            "metadata": binary.packed_metadata
            if hasattr(binary, "packed_metadata")
            else binary.metadata,
            "shared": binary_shared,
        }

        scope["num_warps"] = (
            binary.num_warps
            if hasattr(binary, "num_warps")
            else binary.metadata.num_warps
        )

        scope["cta_args"] = (
            (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
            if hasattr(binary, "num_ctas")
            else (
                (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
                if hasattr(binary, "metadata")
                else ()
            )
        )

        scope["function"] = get_first_attr(binary, "function", "cu_function")

        def get_launch_args_without_kernel_launch_metadata(
            grid,
            grid_0,
            grid_1,
            grid_2,
            stream,
            function,
            metadata,
            bin,
            launch_enter_hook,
            launch_exit_hook,
            num_warps,
            shared,
            cta_args,
            args,
        ):
            """
            Construct launch args before CompiledKernel.launch_metadata is added.
            """
            return (
                grid_0,
                grid_1,
                grid_2,
                num_warps,
                *cta_args,
                shared,
                stream,
                function,
                launch_enter_hook,
                launch_exit_hook,
                metadata,
            )

        # Getting the kernel launch args is extremely perf-sensitive. Evaluating
        # `bin.launch_metadata` is relatively expensive, and returns None unless a
        # `launch_enter_hook` is installed. So if we don't have that hook installed,
        # we want to burn None in to the launch args with zero overhead.
        # See https://github.com/pytorch/pytorch/issues/123597
        if binary.launch_enter_hook:

            def get_launch_args_with_kernel_launch_metadata(
                grid,
                grid_0,
                grid_1,
                grid_2,
                stream,
                function,
                metadata,
                bin,
                launch_enter_hook,
                launch_exit_hook,
                num_warps,
                shared,
                cta_args,
                args,
            ):
                """
                Construct launch args after CompiledKernel.launch_metadata is added
                by https://github.com/openai/triton/pull/3492 .
                """
                return (
                    grid_0,
                    grid_1,
                    grid_2,
                    stream,
                    function,
                    metadata,
                    bin.launch_metadata(grid, stream, *args),
                    launch_enter_hook,
                    launch_exit_hook,
                )

        else:

            def get_launch_args_with_kernel_launch_metadata(
                grid,
                grid_0,
                grid_1,
                grid_2,
                stream,
                function,
                metadata,
                bin,
                launch_enter_hook,
                launch_exit_hook,
                num_warps,
                shared,
                cta_args,
                args,
            ):
                """
                Construct launch args after CompiledKernel.launch_metadata is added
                by https://github.com/openai/triton/pull/3492 .
                """
                return (
                    grid_0,
                    grid_1,
                    grid_2,
                    stream,
                    function,
                    metadata,
                    None,
                    launch_enter_hook,
                    launch_exit_hook,
                )

        scope["get_launch_args"] = (
            get_launch_args_with_kernel_launch_metadata
            if hasattr(binary, "launch_metadata")
            else get_launch_args_without_kernel_launch_metadata
        )

        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")

        exec(
            f"""
            def launcher({', '.join(def_args)}, grid, stream):
                if callable(grid):
                    grid_0, grid_1, grid_2 = grid(grid_meta)
                else:
                    grid_0, grid_1, grid_2 = grid

                args = {', '.join(call_args)},
                launch_args = get_launch_args(
                    grid, grid_0, grid_1, grid_2, stream, function,
                    metadata, bin, launch_enter_hook, launch_exit_hook,
                    num_warps, shared, cta_args, args
                )
                runner(*launch_args, *args)
                return bin
            """.lstrip(),
            scope,
        )

        launcher = scope["launcher"]
        launcher.config = cfg
        launcher.n_regs = getattr(binary, "n_regs", None)
        launcher.n_spills = getattr(binary, "n_spills", None)
        launcher.shared = binary_shared
        launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
        # store this global variable to avoid the high overhead of reading it when calling run
        if launcher.store_cubin:
            launcher.fn = self.fn
            launcher.bin = binary

        return binary, launcher

    def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
        """Measure the performance of a given launcher"""
        # we don't skip configs with spilled registers when auto-tuning custom
        # (user-written) Triton kernels, as (i) we don't have any knowledge or
        # control over the kernel code; (ii) there is empirical evidence that
        # for some (complicated) custom Triton kernels, a register-spilling
        # config may yield the best latency.
        if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
            "spill_threshold", 16
        ):
            log.debug(
                "Skip config %s because of register spilling: %d",
                launcher.config,
                launcher.n_spills,
            )
            return float("inf")

        device_interface = self.get_device_interface()
        stream = device_interface.get_raw_stream(device_interface.current_device())

        def kernel_call():
            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
            launcher(
                *cloned_args,
                **cloned_kwargs,
                grid=grid,
                stream=stream,
            )

        if with_profiler:
            from torch._inductor.utils import do_bench_using_profiling

            return do_bench_using_profiling(kernel_call, warmup=10, rep=40)

        return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True)

    def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
        from ..compile_fx import clone_preserve_strides

        # [Note: clone mutated buffers]
        # clone inplace buffers to avoid autotune contaminating them if
        # the kernel does in-place stores. Avoid cloning other buffers because
        # it leads to increased memory use.
        cloned_args = []
        for i, arg in enumerate(args):
            if self.fn.arg_names[i] in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_args.append(clone_preserve_strides(arg))
            else:
                cloned_args.append(arg)

        cloned_kwargs: Dict[str, Any] = {}
        for name, arg in kwargs.items():
            if name in self.mutated_arg_names:
                assert isinstance(arg, torch.Tensor)
                cloned_kwargs[name] = clone_preserve_strides(arg)
            else:
                cloned_kwargs[name] = arg

        return cloned_args, cloned_kwargs

    def benchmark_all_configs(self, *args, **kwargs):
        with dynamo_timed("CachingAutotuner.benchmark_all_configs"):
            timings = {
                launcher: self.bench(launcher, *args, **kwargs)
                for launcher in self.launchers
            }

            for k, v in timings.items():
                self.coordesc_tuner.cache_benchmark_result(k.config, v)

            if log.isEnabledFor(logging.DEBUG):
                log.debug("Benchmark all input configs for %s, get:", self.fn.__name__)
                for k, v in timings.items():
                    log.debug(
                        "%s: %f, nreg %d, nspill %d, #shared-mem %s",
                        k.config,
                        v,
                        k.n_regs,
                        k.n_spills,
                        k.shared,
                    )

            return timings

    def autotune_to_one_config(self, *args, **kwargs):
        """Do the actual autotuning"""
        start_time = time.time_ns()
        timings = self.benchmark_all_configs(*args, **kwargs)
        benchmark_time_taken_ns = time.time_ns() - start_time
        self.launchers = [builtins.min(timings, key=timings.get)]
        self.autotune_time_taken_ns = (
            self.precompile_time_taken_ns + benchmark_time_taken_ns
        )
        if self.save_cache_hook:
            self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)

    def save_gpu_kernel(self, grid, stream, launcher):
        if callable(grid):
            grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
        else:
            grid_x, grid_y, grid_z = grid

        key = self.inductor_meta.get("kernel_name", None)  # unique kernel name
        assert key is not None, "kernel_name can not be None"
        params = {
            "mangled_name": launcher.bin.metadata.name
            if hasattr(launcher.bin.metadata, "name")
            else launcher.bin.metadata["name"],
            "grid_x": grid_x,
            "grid_y": grid_y,
            "grid_z": grid_z,
            "x_block": launcher.config.kwargs.get("XBLOCK", 1),
            "y_block": launcher.config.kwargs.get("YBLOCK", None),
            "z_block": launcher.config.kwargs.get("ZBLOCK", None),
            "num_warps": launcher.bin.num_warps
            if hasattr(launcher.bin, "num_warps")
            else launcher.bin.metadata.num_warps,
            "shared_mem": launcher.bin.shared
            if hasattr(launcher.bin, "shared")
            else launcher.bin.metadata.shared,
            "stream": stream,
            # User defined triton kernels will have arbitrary kwarg names
            "meta": launcher.config.kwargs,
        }
        from torch._inductor.codecache import CudaKernelParamCache

        bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
        binary = launcher.bin.asm[bin_type]
        CudaKernelParamCache.set(key, params, binary, bin_type)

        self.cuda_kernel_saved = True

    def coordinate_descent_tuning(self, launcher, *args, **kwargs):
        """
        Coordinate descent tuning can be run with or without max-autotune.

        The only difference between the two is the starting config for coordinate descent tuning.
        E.g., assume regular autotune only gets one config C1, while max-autotune gets 4 configs
        C1, C2, C3, C4 and figures out that C3 is the best.

        Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
        """
        if (
            self.heuristic_type == HeuristicType.TEMPLATE
            or self.heuristic_type == HeuristicType.USER_AUTOTUNE
        ):
            # skip triton template
            return launcher

        config2launcher = {launcher.config: launcher}

        def benchmark_one_config(config):
            with self.lock:
                _, launcher = self._precompile_config(config, False)
            config2launcher[config] = launcher

            out = self.bench(launcher, *args, **kwargs)
            log.debug(
                "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
                launcher.config,
                out,
                launcher.n_regs,
                launcher.n_spills,
                launcher.shared,
            )
            return out

        assert not (
            self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
            and "RBLOCK" in launcher.config.kwargs
        ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
        start_time = time.time_ns()
        best_config = self.coordesc_tuner.autotune(
            benchmark_one_config, launcher.config, None
        )
        coordesc_time_taken_ns = time.time_ns() - start_time
        best_config.found_by_coordesc = True

        if self.save_cache_hook:
            self.save_cache_hook(
                best_config,
                self.autotune_time_taken_ns + coordesc_time_taken_ns,
                found_by_coordesc=True,
            )
        return config2launcher.get(best_config)

    def run(self, *args, grid, stream, **kwargs):
        if len(self.launchers) != 1:
            if len(self.launchers) == 0:
                start_time = time.time_ns()
                self.precompile()
                self.precompile_time_taken_ns = time.time_ns() - start_time
            if len(self.launchers) > 1:
                self.autotune_to_one_config(*args, grid=grid, **kwargs)

        if not getattr(
            self.launchers[0].config, "found_by_coordesc", False
        ) and self.inductor_meta.get("coordinate_descent_tuning", False):
            self.launchers = [
                self.coordinate_descent_tuning(
                    self.launchers[0], *args, grid=grid, **kwargs
                )
            ]

        (launcher,) = self.launchers
        if launcher.store_cubin:
            self.save_gpu_kernel(grid, stream, launcher)

        if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", 0) == "1":
            _dump_launch_params(args, kwargs, launcher, self.fn.__name__)

        # checking the profiler flag directly is faster than entering and exiting a
        # context manager, even if the context manager is a nullcontext.
        if autograd_profiler._is_profiler_enabled:
            # grid can be a tuple of ints or a string.
            if isinstance(grid, tuple):
                grid_info = str(grid)
            else:
                grid_info = getattr(grid, "grid_fn_str", "")
            with torch._C._profiler._RecordFunctionFast(
                self.inductor_meta.get("kernel_name", "triton kernel"),
                args,
                {
                    "kernel_file": "" if self.filename is None else self.filename,
                    "kernel_backend": "triton",
                    "grid": grid_info,
                    "stream": stream,
                },
            ):
                return launcher(
                    *args,
                    **kwargs,
                    grid=grid,
                    stream=stream,
                )
        else:
            return launcher(
                *args,
                **kwargs,
                grid=grid,
                stream=stream,
            )


def _find_names(obj):
    import gc
    import inspect

    frame = inspect.currentframe()
    while frame is not None:
        frame.f_locals
        frame = frame.f_back
    obj_names = []
    for referrer in gc.get_referrers(obj):
        if isinstance(referrer, dict):
            for k, v in referrer.items():
                if v is obj:
                    obj_names.append(k)
    return obj_names


collected_calls: List[Any] = []


def start_graph():
    collected_calls.clear()


def end_graph(output_file):
    if len(collected_calls) == 0:
        return
    overall_time = sum(call[0] for call in collected_calls)
    overall_gb = sum(call[1] for call in collected_calls)
    cur_file = inspect.stack()[1].filename
    summary_str = (
        f"SUMMARY ({cur_file})\n"
        f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s"
    )
    print(summary_str)
    print()
    if output_file is not None:
        # sort perf numbers in descending order, i.e. placing the
        # most runtime-heavy kernels at the top of the list
        sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True)
        try:
            with open(output_file, "a") as file:
                log.debug("Save profile bandwidth results to %s", output_file)
                file.write("====================\n")
                file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n")
                for ms, num_gb, gb_per_s, kernel_name in sorted_calls:
                    # also display the runtime percentage for each kernel
                    percentage = f"{ms/overall_time*100:.2f}%"
                    suffix = f" \t {percentage} \t {kernel_name}"
                    bw_info_str = create_bandwidth_info_str(
                        ms,
                        num_gb,
                        gb_per_s,
                        suffix=suffix,
                        color=False,
                    )
                    file.write(bw_info_str + "\n")
                file.write(f"{summary_str}\n\n")
        except Exception as e:
            log.warning(
                "failed to write profile bandwidth result into %s: %s",
                output_file,
                e,
            )


class DebugAutotuner(CachingAutotuner):
    def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs):
        self.regex_filter = regex_filter
        self.with_profiler = with_profiler
        super().__init__(*args, **kwargs)
        self.cached = None

    def run(self, *args, grid, stream):
        possible_names = _find_names(self)
        kernel_name = f"{max(possible_names, key=len)}"
        if not re.match(self.regex_filter, kernel_name):
            return
        super().run(*args, grid=grid, stream=stream)
        (launcher,) = self.launchers

        if self.cached is None:
            ms = self.bench(
                launcher, *args, grid=grid, with_profiler=self.with_profiler
            )
            num_in_out_ptrs = len(
                [
                    arg_name
                    for arg_name in self.fn.arg_names
                    if arg_name.startswith("in_out_ptr")
                ]
            )
            num_gb = self.inductor_meta.get("kernel_num_gb", None)
            if num_gb is None:
                num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
            gb_per_s = num_gb / (ms / 1e3)
            self.cached = ms, num_gb, gb_per_s, kernel_name
            collected_calls.append((ms, num_gb, gb_per_s, kernel_name))
            print(
                create_bandwidth_info_str(
                    ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}"
                )
            )


def hash_configs(configs: List[Config]):
    """
    Hash used to check for changes in configurations
    """
    hasher = hashlib.sha256()
    for cfg in configs:
        hasher.update(
            f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode()
        )
    return hasher.hexdigest()


def cached_autotune(
    size_hints: Optional[List[int]],
    configs: List[Config],
    triton_meta,
    heuristic_type,
    filename=None,
    inductor_meta=None,
    custom_kernel=False,
):
    """
    A copy of triton.autotune that calls our subclass. Our subclass
    has additional debugging, error handling, and on-disk caching.
    """
    configs = unique_configs(configs)
    assert len(configs) == 1 or filename
    inductor_meta = {} if inductor_meta is None else inductor_meta

    disabled = inductor_meta.get("force_disable_caches", False)

    # on disk caching logic and/or remote caching
    autotune_cache = None
    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
    ):
        configs_hash = hash_configs(configs)

        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            if best_config := autotune_cache.read_best(inductor_meta, configs):
                configs = [best_config]

    else:
        if disabled:
            log.debug("autotune caching is disabled by config.force_disable_caches")

    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())

    def decorator(fn):
        # Remove XBLOCK from config if it's not a function argument.
        # This way, coordinate descent tuning will not try to tune it.
        #
        # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1.
        import inspect

        if "XBLOCK" not in inspect.signature(fn.fn).parameters:
            for tconfig in configs:
                if "XBLOCK" in tconfig.kwargs:
                    assert tconfig.kwargs["XBLOCK"] == 1
                    tconfig.kwargs.pop("XBLOCK")

        if inductor_meta.get("profile_bandwidth"):
            return DebugAutotuner(
                fn,
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                regex_filter=inductor_meta["profile_bandwidth_regex"],
                with_profiler=inductor_meta[
                    "profile_bandwidth_with_do_bench_using_profiling"
                ],
                configs=configs,
                save_cache_hook=autotune_cache and autotune_cache.save,
                mutated_arg_names=mutated_arg_names,
                heuristic_type=heuristic_type,
                size_hints=size_hints,
                custom_kernel=custom_kernel,
                filename=filename,
            )
        return CachingAutotuner(
            fn,
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            configs=configs,
            save_cache_hook=autotune_cache and autotune_cache.save,
            mutated_arg_names=mutated_arg_names,
            heuristic_type=heuristic_type,
            size_hints=size_hints,
            custom_kernel=custom_kernel,
            filename=filename,
        )

    return decorator


def unique_configs(configs: List[Config]):
    """Remove duplicate configurations"""
    seen = set()
    pruned_configs = []

    for cfg in configs:
        key = triton_config_to_hashable(cfg)
        if key not in seen:
            seen.add(key)
            pruned_configs.append(cfg)
    return pruned_configs


def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
    for numel, label in zip((xnumel, ynumel, znumel), "XYZ"):
        if numel is None:
            continue
        block = cfg[f"{label}BLOCK"]
        if numel == 1:
            assert block == 1, (
                f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1"
                f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})."
            )
        max_block = TRITON_MAX_BLOCK[label]
        max_block_str = f'config.triton.max_block["{label}"]'
        assert max_block % block == 0, (
            f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}"
            f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})."
        )


def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
    # On AMD GPU each warp has 64 lanes which is double the size on NV GPU,
    # therefore using half the number of warps here correspondingly.
    if torch.version.hip:
        max_num_warps = (max_num_warps + 1) // 2
        min_num_warps = (min_num_warps + 1) // 2
    # persistent reduction is register intensive
    if register_intensive:
        max_num_warps = max_num_warps // 2
    return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps))


def _check_max_grid_x(size_hints, x, num_warps):
    # Check if maxGridSize is exceeded - if so then must scale XBLOCK further
    max_grid_x = 2147483647
    warp_size = (
        64 if torch.version.hip else 32
    )  # TODO: query warp size once #129663 is merged
    num_blocks = (size_hints[0] + x - 1) // x

    while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints[0]:
        x *= 2  # Scale up XBLOCK if grid exceeds limits
        num_blocks = num_blocks // 2
        if x >= max_grid_x:
            raise AssertionError(
                "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue"
            )
    return x, num_blocks


def triton_config(
    size_hints,
    x,
    y=None,
    z=None,
    num_stages=1,
    num_elements_per_warp=256,
    min_elem_per_thread=0,
) -> Config:
    """
    Construct a pointwise triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.

    num_elements_per_warp is a suggestion for controlling how many warps
    the triton config should contain. e.g.: if x=16, y=8, z=4 then
    num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128,
    we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's
    just a suggestion, and sometimes other adjustment heuristics will
    override the num_elements_per_warp.

    min_elem_per_thread controls the minimum number of elements
    processed by each thread. It's always enforced.
    """
    # Ideally we want to read this from some device config

    # for a 2d size_hints [a, b], a should be mapped to YBLOCK rather than XBLOCK
    size_hints = list(reversed(size_hints))

    maxGridSize = [2147483647, 65535, 65535]

    target = conditional_product(x, y, z)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    if y:
        y = min(y, size_hints[1])
    if z:
        z = min(z, size_hints[2])

    # if we are below original block size, scale up where we can;
    # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension
    while x < min(size_hints[0], TRITON_MAX_BLOCK["X"]) and (
        x * maxGridSize[0] < size_hints[0] or conditional_product(x, y, z) < target
    ):
        x *= 2
    while (
        y
        and y < min(size_hints[1], TRITON_MAX_BLOCK["Y"])
        and (
            y * maxGridSize[1] < size_hints[1] or conditional_product(x, y, z) < target
        )
    ):
        y *= 2
    while (
        z
        and z < min(size_hints[2], TRITON_MAX_BLOCK["Z"])
        and (
            z * maxGridSize[2] < size_hints[2] or conditional_product(x, y, z) < target
        )
    ):
        z *= 2

    num_warps = _num_warps(
        conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
    )
    # we are going to arrive at 2 warps only if bs was too small due to
    # numel being too small. However to workaround some ptx bugs we still
    # want at least 4 warps if there's enough elements per thread
    # given that this is a rare situation, don't expect this to affect perf
    # in general
    # see https://github.com/pytorch/pytorch/pull/97950
    if conditional_product(x, y, z) >= 128 and not torch.version.hip:
        num_warps = max(num_warps, 4)
    xnumel = size_hints[0]
    ynumel = size_hints[1] if y else None
    znumel = size_hints[2] if z else None

    # Increase x to satisfy min_elem_per_thread requirements.
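    # For illustration (assumed values): with min_elem_per_thread=2, num_warps=4
    # and _NUM_THREADS_PER_WARP=32, the floor is 2 * 32 * 4 = 256 elements; if the
    # current block only covers conditional_product(x, y, z) == 128 elements,
    # x is scaled by ceil(256 / 128) == 2 to meet the minimum.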
    block_size = max(
        conditional_product(x, y, z),
        min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps,
    )
    x *= math.ceil(block_size / conditional_product(x, y, z))

    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)

    cfg = {"XBLOCK": x}
    if y:
        cfg["YBLOCK"] = y
    if z:
        cfg["ZBLOCK"] = z
    assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
    check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)


def triton_config_reduction(
    size_hints, x, r, num_stages=1, num_warps=None, register_intensive=False
) -> Config:
    """
    Construct a reduction triton config with some adjustment heuristics
    based on size_hints. Size_hints is a tuple of numels in each tile
    dimension and will be rounded up to the nearest power of 2.
    """

    target = conditional_product(x, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    r = min(r, size_hints[1])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, r) < target:
        x *= 2
    while r < size_hints[1] and conditional_product(x, r) < target:
        r *= 2

    if num_warps is None:
        num_warps = conditional_product(x, r) // 128
    num_warps = _num_warps(
        num_warps, max_num_warps=16, register_intensive=register_intensive
    )

    x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)

    while conditional_product(x, r) > target:
        if r == 1:
            break
        r = r // 2

    cfg = {"XBLOCK": x, "RBLOCK": r}
    check_config(cfg, xnumel=size_hints[0])
    assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
    assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['R'] to {r}"
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)


def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
    """
    Construct a tile reduction triton config with some adjustment
    heuristics based on size_hints. Size_hints is a tuple of numels in
    each tile dimension and will be rounded up to the nearest power of 2.
    """

    target = conditional_product(x, y, r)
    if conditional_product(*size_hints) < target:
        target //= 8

    # shrink sizes to size hints
    x = min(x, size_hints[0])
    y = min(y, size_hints[1])
    r = min(r, size_hints[2])

    # if we are below original block size, scale up where we can
    while x < size_hints[0] and conditional_product(x, y, r) < target:
        x *= 2
    while r < size_hints[2] and conditional_product(x, y, r) < target:
        r *= 2
    while y < size_hints[1] and conditional_product(x, y, r) < target:
        y *= 2

    cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
    num_warps = _num_warps(conditional_product(x, y, r) // 256, min_num_warps=1)
    check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
    assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['R'] to {r}"
    return Config(cfg, num_warps=num_warps, num_stages=num_stages)


def pointwise(
    size_hints,
    triton_meta,
    tile_hint=None,
    filename=None,
    min_elem_per_thread=0,
    inductor_meta=None,
):
    """
    Construct @triton.heuristics() based on size_hints.
    """
    inductor_meta = {} if inductor_meta is None else inductor_meta
    assert not inductor_meta.get("no_x_dim")

    numel = functools.reduce(operator.mul, size_hints)
    bs = max(256, min(numel // 128, 1024))

    hinted_configs = autotune_hints_to_configs(
        inductor_meta.get("autotune_hints", set()), size_hints, bs
    )

    triton_config_with_settings = functools.partial(
        triton_config, min_elem_per_thread=min_elem_per_thread
    )

    if len(size_hints) == 1:
        if disable_pointwise_autotuning(inductor_meta) and not (
            inductor_meta.get("max_autotune")
            or inductor_meta.get("max_autotune_pointwise")
        ):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, bs)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        else:
            return cached_autotune(
                size_hints,
                [
                    triton_config_with_settings(
                        size_hints, bs, num_elements_per_warp=256
                    ),
                    triton_config_with_settings(
                        size_hints, bs // 2, num_elements_per_warp=64
                    ),
                    *hinted_configs,
                ],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
    if len(size_hints) == 2:
        if (
            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
        ) and not (
            inductor_meta.get("max_autotune")
            or inductor_meta.get("max_autotune_pointwise")
        ):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, 32, 32)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        return cached_autotune(
            size_hints,
            [
                triton_config_with_settings(size_hints, 32, 32),
                triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                triton_config_with_settings(size_hints, 256, 16),
                triton_config_with_settings(size_hints, 16, 256),
                triton_config_with_settings(size_hints, bs, 1),
                triton_config_with_settings(size_hints, 1, bs),
                *hinted_configs,
            ],
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            filename=filename,
            heuristic_type=HeuristicType.POINTWISE,
        )
    if len(size_hints) == 3:
        if disable_pointwise_autotuning(inductor_meta):
            return cached_autotune(
                size_hints,
                [triton_config_with_settings(size_hints, 16, 16, 16)],
                triton_meta=triton_meta,
                inductor_meta=inductor_meta,
                heuristic_type=HeuristicType.POINTWISE,
                filename=filename,
            )
        return cached_autotune(
            size_hints,
            [
                triton_config_with_settings(size_hints, 16, 16, 16),
                triton_config_with_settings(size_hints, 64, 8, 8),
                triton_config_with_settings(size_hints, 8, 64, 8),
                triton_config_with_settings(size_hints, 8, 8, 64),
                triton_config_with_settings(size_hints, bs, 1, 1),
                triton_config_with_settings(size_hints, 1, bs, 1),
                triton_config_with_settings(size_hints, 1, 1, bs),
                *hinted_configs,
            ],
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            filename=filename,
            heuristic_type=HeuristicType.POINTWISE,
        )
    raise NotImplementedError(f"size_hints: {size_hints}")
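

# Illustrative sizing note (example numbers, not derived from any particular model):
# for pointwise() with size_hints=[64, 4096], numel == 262144, so
# bs = max(256, min(262144 // 128, 1024)) == 1024, and when autotuning is enabled the
# 2D candidate pool is (32, 32), (64, 64), (256, 16), (16, 256), (1024, 1), (1, 1024)
# plus any hinted configs.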


def _reduction_configs(
    *, size_hints: List[int], inductor_meta: Dict[str, Any]
) -> List[Config]:
    reduction_hint = inductor_meta.get("reduction_hint", None)
    assert len(size_hints) == 2
    rnumel = size_hints[-1]

    register_intensive = False
    MAX_RBLOCK = 2048
    if (
        size_hints[0] >= 1024
        and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
        >= 10
    ):
        # A heuristic to reduce RBLOCK if a kernel potentially needs many registers.
        # Consider loads and reductions, since loads need to move data into
        # registers and reductions need an accumulator.
        #
        # The magic numbers are a bit arbitrary.
        #
        # We cannot rely on dynamically scaling down RBLOCK later, since sometimes
        # triton ends up using fewer registers with worse perf. Check:
        # https://github.com/pytorch/pytorch/issues/126463
        #
        # The heuristic is a very simple one since registers can be reused. But
        # hopefully it can be a good enough indicator.
        MAX_RBLOCK = 1024
        register_intensive = True

    contiguous_config = triton_config_reduction(
        size_hints,
        1,
        (rnumel if 256 <= rnumel < MAX_RBLOCK else MAX_RBLOCK),
        register_intensive=register_intensive,
    )
    outer_config = triton_config_reduction(
        size_hints, 64, 8, register_intensive=register_intensive
    )
    tiny_config = triton_config_reduction(
        size_hints,
        2 * (256 // rnumel) if rnumel <= 256 else 1,
        min(rnumel, MAX_RBLOCK),
        register_intensive=register_intensive,
    )
    if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
        pass  # skip all these cases
    elif reduction_hint == ReductionHint.INNER:
        return [contiguous_config]
    elif reduction_hint == ReductionHint.OUTER:
        return [outer_config]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        return [tiny_config]
    if disable_pointwise_autotuning(inductor_meta):
        return [triton_config_reduction(size_hints, 32, 128)]
    return [
        contiguous_config,
        outer_config,
        tiny_config,
        triton_config_reduction(size_hints, 64, 64),
        triton_config_reduction(size_hints, 8, 512),
        # halve the XBLOCK/RBLOCK compared to outer_config
        # TODO: this may only be beneficial when each iteration of the reduction
        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
        triton_config_reduction(size_hints, 64, 4, num_warps=8),
    ]


def reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """args to @triton.heuristics()"""
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints = [1, *size_hints[1:]]

    assert triton_meta is not None
    rnumel = size_hints[-1]
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.REDUCTION,
        filename=filename,
    )


def persistent_reduction(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints = [1, *size_hints[1:]]

    xnumel, rnumel = size_hints

    configs = [
        triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
        for xblock in (1, 8, 32, 128)
        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
    ]

    # TODO(jansel): we should be able to improve these heuristics
    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
        configs = configs[:1]
    elif reduction_hint == ReductionHint.OUTER:
        configs = configs[-1:]
    elif reduction_hint == ReductionHint.OUTER_TINY:
        configs = [
            triton_config_reduction(
                size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
            )
        ]
    for c in configs:
        # we don't need RBLOCK for persistent reduction
        c.kwargs.pop("RBLOCK")

    if disable_pointwise_autotuning(inductor_meta):
        configs = configs[:1]

    return cached_autotune(
        size_hints,
        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        filename=filename,
        heuristic_type=HeuristicType.PERSISTENT_REDUCTION,
    )


def split_scan(
    size_hints,
    reduction_hint=False,
    triton_meta=None,
    filename=None,
    inductor_meta=None,
):
    """Heuristic for TritonSplitScanKernel"""
    inductor_meta = {} if inductor_meta is None else inductor_meta
    inductor_meta["reduction_hint"] = reduction_hint
    if inductor_meta.get("no_x_dim"):
        size_hints = [1, *size_hints[1:]]

    assert triton_meta is not None
    if len(size_hints) != 2:
        raise NotImplementedError(f"size_hints: {size_hints}")

    configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)

    # Fixup configs to enforce the minimum RBLOCK size
    min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
    for cfg in configs:
        if cfg.kwargs["RBLOCK"] < min_rblock:
            cfg.kwargs["RBLOCK"] = min_rblock

    return cached_autotune(
        size_hints,
        configs=configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.SPLIT_SCAN,
        filename=filename,
    )


def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton template
    """
    return cached_autotune(
        None,
        [triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )


def user_autotune(
    configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False
):
    """
    Compile a user defined triton kernel
    """
    defaults = inspect.signature(triton.Config).parameters
    default_num_stages = defaults["num_stages"].default
    default_num_warps = defaults["num_warps"].default

    if len(configs) == 0:
        configs = [
            triton.Config(
                {}, num_stages=default_num_stages, num_warps=default_num_warps
            )
        ]
    else:
        configs = [
            triton.Config(
                c.get("kwargs", {}),
                num_stages=c.get("num_stages", default_num_stages),
                num_warps=c.get("num_warps", default_num_warps),
            )
            for c in configs
        ]

    return cached_autotune(
        None,
        configs,
        triton_meta=triton_meta,
        heuristic_type=HeuristicType.USER_AUTOTUNE,
        filename=filename,
        inductor_meta=inductor_meta,
        custom_kernel=custom_kernel,
    )


def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    return cached_autotune(
        None,
        [triton.Config({}, num_stages=1, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
        filename=filename,
    )


def grid(*numels):
    """Helper function to compute triton grids"""
    if len(numels) == 1:
        xnumel, ynumel, znumel = numels[0], None, None
    elif len(numels) == 2:
        xnumel, ynumel, znumel = numels[1], numels[0], None
    elif len(numels) == 3:
        xnumel, ynumel, znumel = numels[2], numels[1], numels[0]
    else:
        raise AssertionError(f"invalid size for numels {len(numels)}")

    def get_grid_dim(numel, block):
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    def grid_fn(meta):
        x_grid = get_grid_dim(xnumel, meta.get("XBLOCK", 1))
        y_grid = get_grid_dim(ynumel, meta.get("YBLOCK", None))

        max_y_grid = get_max_y_grid()
        if znumel is None:
            div = ceildiv(y_grid, max_y_grid)
            y_grid = ceildiv(y_grid, div)
            z_grid = div
        else:
            z_grid = get_grid_dim(znumel, meta.get("ZBLOCK", None))
            torch._check(
                y_grid <= max_y_grid,
                lambda: f"Generated y grid beyond 2^16 ({y_grid}) not supported with z dimension present. File issue",
            )

        return (
            x_grid,
            y_grid,
            z_grid,
        )

    setattr(grid_fn, "grid_fn_str", f"grid{numels}")  # noqa: B010

    return grid_fn


def split_scan_grid(xnumel, rnumel):
    def grid_fn(meta):
        assert meta.get("XBLOCK", 1) == 1
        return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1)

    grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010

    return grid_fn


def grid_combo_kernels(
    *numels, num_kernels, min_blocks, is_sequential, default_meta=None
):
    """min_blocks is the minimal size of the grid x dimension"""
    if not is_sequential:
        # round robin dispatch
        numels_agg = list(numels)
        for i in range(len(numels_agg)):
            if isinstance(numels_agg[i], (list, tuple)):
                numels_agg[i] = max(max(numels_agg[i]), 0)  # noqa: PLW3301
        kernel_grid_fn = grid(*numels_agg)

        if isinstance(numels[-1], (list, tuple)):
            min_blocks_d = max(-min(numels[-1]), 0) * num_kernels
        else:
            min_blocks_d = None
        if min_blocks is None:
            assert min_blocks_d is not None
            min_blocks = min_blocks_d
        else:
            assert (
                min_blocks_d is None or min_blocks == min_blocks_d
            ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
    else:
        # sequential dispatch
        seq_numels = list(numels)
        # x numels are not used here, just a place holder
        seq_numels[-1] = 1024
        for i in range(len(seq_numels) - 1):
            if isinstance(seq_numels[i], (list, tuple)):
                seq_numels[i] = max(seq_numels[i])

        kernel_grid_fn = grid(*seq_numels)

    def get_grid_dim(numel, block):
        if numel is None:
            return 1
        if block is None:
            return numel
        return ceildiv(numel, block)

    def grid_fn(meta):
        assert min_blocks is not None, "min_blocks must be a number"
        cuda_grid = list(kernel_grid_fn(meta))
        cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
        return tuple(cuda_grid)

    def seq_grid_fn(meta):
        cuda_grid = list(kernel_grid_fn(meta))
        # x <= 0 means this kernel's x grid is not tunable (x_no_dim is true)
        x_grid = sum(
            [
                -x if x <= 0 else get_grid_dim(x, meta.get("XBLOCK", 1))
                for x in numels[-1]
            ]
        )
        cuda_grid[0] = x_grid
        return tuple(cuda_grid)

    def grid_fn_default_meta(meta):
        return grid_fn(default_meta)

    def seq_grid_fn_default_meta(meta):
        return seq_grid_fn(default_meta)

    if default_meta is None:
        return grid_fn if not is_sequential else seq_grid_fn
    else:
        return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta
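

# Usage sketch for grid() (illustrative only, assuming get_max_y_grid() is not
# exceeded): grid(1024, 2048) builds a callable mapping a meta dict to a launch
# grid; the last numel maps to the x dimension, so with
# meta={"XBLOCK": 256, "YBLOCK": 64} it returns
# (ceildiv(2048, 256), ceildiv(1024, 64), 1) == (8, 16, 1).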