import itertools
import logging
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from sympy import Integer, Symbol

from torch.utils._ordered_set import OrderedSet

from .. import config, metrics
from ..runtime.hints import DeviceProperties, ReductionHint
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..scheduler import BaseSchedulerNode
from ..utils import Placeholder
from ..virtualized import V
from .common import (
    DeferredLine,
    IndentedBuffer,
    Kernel,
    PythonPrinter,
    SizeArg,
    WorkspaceArg,
)
from .simd import SIMDScheduling
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta


log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint
LARGE_NUMELS = 512e5
BLOCK_UTILIZATION = 0.8


def _default_custom_combo_kernel_horizontal_partition(
    nodes: List[BaseSchedulerNode],
    triton_scheduling: SIMDScheduling,
    kernel_map: Dict[BaseSchedulerNode, TritonKernel],
    node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
) -> List[List[BaseSchedulerNode]]:
    """Horizontally partition the given list of nodes into a list of lists of nodes, where each sublist
    represents a partition. Nodes in different partitions are implemented in different combo kernels.
    Nodes in the same partition are likely to be implemented in the same combo kernel, but subject to
    subsequent restrictions like CUDA limits on the number of args.

    Input arguments:
        nodes: a list of fused scheduler nodes to partition.
        triton_scheduling: TritonScheduling instance.
        kernel_map: a map from node to its kernel.
        node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
    Output:
        a list of lists of nodes, with each sublist representing a partition.

    The default algorithm partitions nodes based on the following rules:
        1) nodes with the same number of block dimensions are grouped together.
        2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes.
        3) large reduce nodes are separated from other nodes.
72 """ 73 74 assert len(nodes) >= 1 75 76 # first partition nodes based on number of block dimensions 77 tilings = [node_info_map[n][1] for n in nodes] 78 79 max_dims = max(len(t) for t in tilings) 80 nodes_per_ndim = [] 81 for i in range(2, max_dims + 1): 82 group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i] 83 reduction = [ 84 n 85 for n in group_per_dim 86 if kernel_map[n].inside_reduction 87 and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim) 88 ] 89 not_reduction = [n for n in group_per_dim if n not in reduction] 90 # rnumel > 2048 usually has long execution time 91 # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes 92 long_reduction = [ 93 n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 94 ] 95 short_reduction = [n for n in reduction if n not in long_reduction] 96 if long_reduction: 97 log.warning( 98 "ComboKernels: %d long reduction nodes are separated", 99 len(long_reduction), 100 ) 101 large_pointwise = [ 102 n 103 for n in not_reduction 104 if not kernel_map[n].inside_reduction 105 and len(kernel_map[n].numels) == 2 106 and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS 107 ] 108 if large_pointwise: 109 # TODO benchmark the performance when large pointwise nodes combining with others 110 log.warning( 111 "ComboKernels: %d large pointwise nodes are separated", 112 len(large_pointwise), 113 ) 114 not_reduction = [n for n in not_reduction if n not in large_pointwise] 115 for node in large_pointwise: 116 nodes_per_ndim.append([node]) 117 118 for g in (not_reduction, short_reduction, long_reduction): 119 if g: 120 nodes_per_ndim.append(g) 121 122 assert sum(len(p) for p in nodes_per_ndim) == len(nodes) 123 return nodes_per_ndim 124 125 126_custom_combo_kernel_horizontal_partition_algorithm: Callable[ 127 [ 128 List[BaseSchedulerNode], 129 SIMDScheduling, 130 Dict[BaseSchedulerNode, TritonKernel], 131 Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], 132 ], 133 List[List[BaseSchedulerNode]], 134] = _default_custom_combo_kernel_horizontal_partition 135 136 137def set_custom_combo_kernel_horizontal_partition( 138 algorithm: Callable[ 139 [ 140 List[BaseSchedulerNode], 141 SIMDScheduling, 142 Dict[BaseSchedulerNode, TritonKernel], 143 Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], 144 ], 145 List[List[BaseSchedulerNode]], 146 ] 147) -> None: 148 """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions 149 are implemented in different combo kernels. Nodes in the same partition are likely to be implemented 150 in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args. 151 152 The algorithm should take a list of nodes and return a list of list of nodes. 153 154 The default algorithm is to partition nodes based on number of block dimensions. 
155 """ 156 global _custom_combo_kernel_horizontal_partition_algorithm 157 _custom_combo_kernel_horizontal_partition_algorithm = algorithm 158 159 160@dataclass 161class PartitionState: 162 partitions: List[List[BaseSchedulerNode]] 163 cur_partition: List[BaseSchedulerNode] 164 cur_count: int 165 166 def finalize(self) -> None: 167 if self.cur_partition: 168 self.partitions.append(self.cur_partition) 169 170 171class ComboKernel(Kernel): 172 MAX_NUM_ARGS = 250 # number where I would no longer get triton errors 173 174 @staticmethod 175 def _update_partition( 176 partition_state: PartitionState, 177 node_rw_count: int, 178 node_info: BaseSchedulerNode, 179 ) -> None: 180 if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS: 181 partition_state.partitions.append(partition_state.cur_partition) 182 partition_state.cur_partition = [node_info] 183 partition_state.cur_count = node_rw_count 184 else: 185 partition_state.cur_count += node_rw_count 186 partition_state.cur_partition.append(node_info) 187 188 @staticmethod 189 def _base_horizontal_partition( 190 subkernel_nodes: List[BaseSchedulerNode], 191 triton_scheduling: SIMDScheduling, 192 node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], 193 custom_algorithm: bool, 194 ) -> List[List[BaseSchedulerNode]]: 195 """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) 196 for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args 197 (read/writes) and to have the same 2D or 1D blocking strategy.""" 198 # TODO support combination of kernels with different block dimensions 199 assert len(subkernel_nodes) >= 1 200 mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( 201 config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm 202 ) 203 204 ndim_to_partition_state: Dict[int, PartitionState] = defaultdict( 205 lambda: PartitionState([], [], 0) 206 ) 207 yelem_to_partition_state: Dict[int, PartitionState] = defaultdict( 208 lambda: PartitionState([], [], 0) 209 ) 210 211 for node in subkernel_nodes: 212 node_schedule, tiled_groups, numel, rnumel = node_info_map[node] 213 node_info = node 214 215 read_writes = node.read_writes 216 read_write_count = len(read_writes.reads) + len(read_writes.writes) 217 218 ndim = len(tiled_groups) 219 assert ndim >= 2, f"Combokernel not support tile {tiled_groups}" 220 if not mixed_sizes and ndim == 3: 221 y_elem = tiled_groups[0] 222 partition_state = yelem_to_partition_state[y_elem] 223 ComboKernel._update_partition( 224 partition_state, read_write_count, node_info 225 ) 226 else: 227 assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}" 228 partition_state = ndim_to_partition_state[ndim] 229 ComboKernel._update_partition( 230 partition_state, read_write_count, node_info 231 ) 232 233 all_partitions = [] 234 for partition_state in ndim_to_partition_state.values(): 235 partition_state.finalize() 236 all_partitions.extend(partition_state.partitions) 237 for partition_state in yelem_to_partition_state.values(): 238 partition_state.finalize() 239 all_partitions.extend(partition_state.partitions) 240 241 return all_partitions 242 243 @staticmethod 244 def horizontal_partition( 245 nodes: List[BaseSchedulerNode], 246 triton_scheduling: SIMDScheduling, 247 kernel_map: Dict[BaseSchedulerNode, TritonKernel], 248 node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]], 249 custom_algorithm: bool = False, 250 ) -> List[List[BaseSchedulerNode]]: 251 """Generates a 

    @staticmethod
    def horizontal_partition(
        nodes: List[BaseSchedulerNode],
        triton_scheduling: SIMDScheduling,
        kernel_map: Dict[BaseSchedulerNode, TritonKernel],
        node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        custom_algorithm: bool = False,
    ) -> List[List[BaseSchedulerNode]]:
        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
        for each subkernel node, where each sublist forms a ComboKernel. It horizontally partitions nodes into
        sublists in the following way:
            1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True
            2) then, call _base_horizontal_partition() to partition nodes into sublists; each sublist is
               guaranteed not to exceed CUDA limits for number of args (read/writes) and to have the same
               2D or 1D blocking strategy.
        """
        if custom_algorithm:
            raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
                nodes, triton_scheduling, kernel_map, node_info_map
            )
        else:
            raw_partitions = [nodes]

        all_partitions = []
        for raw_partition in raw_partitions:
            all_partitions.extend(
                ComboKernel._base_horizontal_partition(
                    raw_partition, triton_scheduling, node_info_map, custom_algorithm
                )
            )
        return all_partitions

    class SequentialDispatch:
        """
        The dispatcher which dispatches the subkernels in a sequential manner:
        the blocks are first dispatched to the 1st subkernel (until it is filled),
        then to the 2nd subkernel, and so on.
        The class defines the methods specific to the dispatch algorithm.
        Methods:
            codegen_pid_range(...): codegen the pid range for each subkernel.
            grid(...): codegen the grid size for launching the combo kernel.
        """

        @classmethod
        def codegen_pid_range(
            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
        ) -> None:
            if num == 0:
                cls._calculate_xblocks(kernel, code)
                code.splice(f"if pid < num_xblocks_{num}:")
                with code.indent():
                    code.splice("pid_offset = pid")
            else:
                code.splice(f"elif pid < num_xblocks_{num}:")
                with code.indent():
                    code.splice(f"pid_offset = pid - num_xblocks_{num-1}")

        @classmethod
        def _calculate_xblocks(
            cls, kernel: "ComboKernel", code: IndentedBuffer
        ) -> None:
            x_numels_list = kernel.x_numels_list
            for i in range(len(x_numels_list)):
                xnumels, no_x_dim = (
                    (x_numels_list[i], False)
                    if isinstance(x_numels_list[i], str)
                    and cast(str, x_numels_list[i])[0] != "-"
                    or (
                        isinstance(x_numels_list[i], int)
                        and cast(int, x_numels_list[i]) > 0
                    )
                    else (kernel.min_x_blocks_list[i], True)
                )
                xblock_str = (
                    f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}"
                )
                if i == 0:
                    code.splice(f"num_xblocks_{i} = {xblock_str}")
                else:
                    code.splice(f"num_xblocks_{i} = num_xblocks_{i-1} + {xblock_str}")

        @classmethod
        def grid(
            cls,
            sub_kernel_numels: List[List[int]],
            x_blocks_list: List[Union[str, int]],
            dynamic_shape: bool,
        ) -> Tuple[Any, ...]:
            xnumel = list(x_blocks_list)
            ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
            znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

            if dynamic_shape:
                ynumel = None if None in ynumel else ynumel
                znumel = None if None in znumel else znumel
            else:
                # TODO: improve 1d/2d mixed cases
                ynumel = (
                    None
                    if any(e is None for e in cast(List[Any], ynumel))
                    else max(cast(Iterable[int], ynumel))
                )
                znumel = (
                    None
                    if any(e is None for e in cast(List[Any], znumel))
                    else max(cast(Iterable[int], znumel))
                )

            numels = (
                (xnumel,)
                if not ynumel
                else (ynumel, xnumel)
                if not znumel
                else (znumel, ynumel, xnumel)
            )
            return numels
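
    # Illustrative sketch of the dispatch prologue codegen_pid_range emits for a
    # combo kernel with two sub-kernels under sequential dispatch (the names and
    # dynamic numels below are made up for the example):
    #
    #     num_xblocks_0 = tl.cdiv(xnumel_0, XBLOCK)
    #     num_xblocks_1 = num_xblocks_0 + tl.cdiv(xnumel_1, XBLOCK)
    #     if pid < num_xblocks_0:
    #         pid_offset = pid
    #         ...  # body of sub-kernel 0
    #     elif pid < num_xblocks_1:
    #         pid_offset = pid - num_xblocks_0
    #         ...  # body of sub-kernel 1
    #     else:
    #         pass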

    class RoundRobinDispatch:
        """
        The dispatcher which dispatches the subkernels in a round robin manner:
        the blocks are dispatched to the subkernels in an interleaved fashion so
        that they execute in parallel.
        The class defines the methods specific to the dispatch algorithm.
        Methods:
            codegen_pid_range(...): codegen the pid range for each subkernel.
            grid(...): codegen the grid size for launching the combo kernel.
        """

        @classmethod
        def codegen_pid_range(
            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
        ) -> None:
            num_kernels = len(kernel.sub_kernels)
            if num == 0:
                cond = "if"
            else:
                cond = "elif"
            code.splice(f"{cond} pid % {num_kernels} == {num}:")
            with code.indent():
                code.splice(f"pid_offset = pid // {num_kernels}")

        @classmethod
        def grid(
            cls,
            sub_kernel_numels: List[List[int]],
            x_blocks_list: List[Union[str, int]],
            dynamic_shape: bool,
        ) -> Tuple[Any, ...]:
            xnumel = x_blocks_list
            # set no_x_dim xnumels to 0
            xnumel_x_dim = [max(e, 0) for e in xnumel]
            ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
            znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

            # TODO: support 1d/2d mixed cases
            xnumel = (
                None
                if any(e is None for e in xnumel)
                else xnumel
                if dynamic_shape
                else max(xnumel_x_dim)  # type: ignore[type-var, arg-type]
            )
            ynumel = (
                None
                if any(e is None for e in ynumel)
                else ynumel
                if dynamic_shape
                else max(ynumel)  # type: ignore[type-var, arg-type]
            )
            znumel = (
                None
                if any(e is None for e in znumel)
                else znumel
                if dynamic_shape
                else max(znumel)  # type: ignore[type-var, arg-type]
            )

            numels = (
                (xnumel,)
                if not ynumel
                else (ynumel, xnumel)
                if not znumel
                else (znumel, ynumel, xnumel)
            )
            return numels
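
    # Illustrative sketch of the dispatch prologue codegen_pid_range emits for a
    # combo kernel with three sub-kernels under round-robin dispatch; blocks are
    # interleaved across sub-kernels by taking pid modulo the sub-kernel count:
    #
    #     if pid % 3 == 0:
    #         pid_offset = pid // 3
    #         ...  # body of sub-kernel 0
    #     elif pid % 3 == 1:
    #         pid_offset = pid // 3
    #         ...  # body of sub-kernel 1
    #     elif pid % 3 == 2:
    #         pid_offset = pid // 3
    #         ...  # body of sub-kernel 2
    #     else:
    #         pass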

    def __init__(
        self, enable_autotune: bool = False, mixed_sizes: bool = False
    ) -> None:
        super().__init__()
        self.sub_kernels: List[TritonKernel] = []
        self.iter_vars_count = itertools.count()
        self.grids: List[List[int]] = []
        self.min_x_blocks_list: List[Union[int, str]] = []
        self.x_numels_list: List[Union[int, str]] = []
        self.enable_autotune = enable_autotune
        self.mixed_sizes = mixed_sizes
        self.dispatch_class: Optional[
            Union[
                Type[ComboKernel.SequentialDispatch],
                Type[ComboKernel.RoundRobinDispatch],
            ]
        ] = None
        self.block_args: List[str] = []
        # the following are used when autotuning is disabled
        self.block_size_1d = 1024  # Try tuning this value
        self.block_size_2d = 32
        self.num_warps = 8
        self.block_size_reduce = 256
        self.dynamic_shape_args: List[str] = []

    def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
        sub_kernel = triton_kernel
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(sub_kernel)
        return sub_kernel

    @staticmethod
    def create_triton_kernel(
        *groups: Any,
        index_dtype: str,
        mutations: OrderedSet[str],
        reduction_hint: ReductionHint,
        optimize_mask: bool,
    ) -> TritonKernel:
        """
        Only allow optimize_mask=True when 1) sequential dispatch is used,
        2) numels except the x dimension are the same for each sub kernel.
        """
        return TritonKernel(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache={"tl.program_id(0)": "pid_offset"},
            reduction_hint=reduction_hint,
            optimize_mask=optimize_mask,
        )

    def codegen_static_numels_sub_kernel(
        self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int
    ) -> List[str]:
        """
        We get a small speedup from hard coding numels if they are static.

        This code stomps on the passed-in values by writing a constant to the top of the kernel.

        In a kernel like:
        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):

        We would add
        xnumel = 4096
        rnumel = 768

        after the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes
        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
        knows that it's a static numel, as that you just plop a constant into the kernel.
        """
        grid = []
        uniquify_block_sizes = []
        for tree in sub_kernel.range_trees:
            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
            if isinstance(simplified_tree_numel, (Integer, int)):
                code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
            else:
                assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args
                uniquify_block_sizes.append(f"{tree.prefix}numel")

            if tree.prefix != "r":
                if isinstance(simplified_tree_numel, (Integer, int)):
                    grid.append(int(simplified_tree_numel))
                else:
                    grid.append(f"{tree.prefix}numel_{num}")

            if tree.prefix == "r" and sub_kernel.persistent_reduction:
                if isinstance(simplified_tree_numel, (Integer, int)):
                    val = int(simplified_tree_numel)
                else:
                    raise RuntimeError(
                        "Dynamic shape on reduction dimension is not supported"
                    )
                val = next_power_of_2(val)
                code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
                uniquify_block_sizes.append("RBLOCK")

            if tree.prefix == "x" and sub_kernel.no_x_dim:
                code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
                uniquify_block_sizes.append("XBLOCK")
        self.grids.append(grid)
        return uniquify_block_sizes

    def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None:
        """
        Kernels with no_x_dim set to True have no tunable XBLOCK; they have a fixed number of X blocks.
        The grid calculation needs to make sure that they are assigned a sufficient number of blocks.
543 """ 544 min_x_blocks: Union[int, str] = 0 545 x_numels: Union[int, str] = 0 546 for tree in sub_kernel.range_trees: 547 simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) 548 if tree.prefix == "x": 549 if isinstance(simplified_tree_numel, (Integer, int)): 550 x_numels = int(simplified_tree_numel) 551 else: 552 x_numels = f"{tree.prefix}numel_{num}" 553 if sub_kernel.no_x_dim: 554 min_x_blocks = x_numels 555 x_numels = ( 556 -min_x_blocks 557 if isinstance(x_numels, int) 558 else "-" + cast(str, x_numels) 559 ) 560 else: 561 if isinstance(simplified_tree_numel, (Integer, int)): 562 x_numels = int(simplified_tree_numel) 563 else: 564 x_numels = f"{tree.prefix}numel_{num}" 565 self.min_x_blocks_list.append(min_x_blocks) 566 self.x_numels_list.append(x_numels) 567 568 def select_heuristics(self, sub_kernel: TritonKernel) -> Tuple[str, List[int]]: 569 size_hints = [ 570 next_power_of_2(V.graph.sizevars.size_hint(numel)) 571 for numel in sub_kernel.numels 572 ] 573 if sub_kernel.persistent_reduction: 574 assert sub_kernel.inside_reduction 575 heuristics = "persistent_reduction" 576 elif sub_kernel.inside_reduction: 577 heuristics = "reduction" 578 else: 579 size_hints.pop() 580 heuristics = "pointwise" 581 return heuristics, size_hints 582 583 def select_combo_heuristics( 584 self, heuristics_list: List[str], size_hints_list: List[List[int]] 585 ) -> Tuple[str, List[int], TritonKernel]: 586 if not self.enable_autotune: 587 return "foreach", size_hints_list[0], self.sub_kernels[0] 588 if "reduction" in heuristics_list: 589 i, _ = max( 590 enumerate(size_hints_list), 591 key=lambda x: x[1][0] if heuristics_list[x[0]] == "reduction" else 0, 592 ) 593 return heuristics_list[i], size_hints_list[i], self.sub_kernels[i] 594 elif "pointwise" in heuristics_list: 595 i, _ = max( 596 enumerate(size_hints_list), 597 key=lambda x: x[1][0] if heuristics_list[x[0]] == "pointwise" else 0, 598 ) 599 # modify size_hint to avoid oom check fail (may be a false alarm) 600 num_pointwise = len([e for e in heuristics_list if e == "pointwise"]) 601 num_reduction = len([e for e in heuristics_list if e == "reduction"]) 602 num_persistent_reduction = len( 603 [e for e in heuristics_list if e == "persistent_reduction"] 604 ) 605 assert ( 606 num_reduction == 0 607 ), "combining pointwise and reduction are not supported yet." 
            heuristics = (
                "pointwise_with_reduction"
                if num_persistent_reduction > 0
                else "pointwise"
            )
            if len(heuristics_list) - num_pointwise >= 4:
                size_hints = size_hints_list[i]
                size_hints[0] = min(128, size_hints[0])
            return heuristics, size_hints_list[i], self.sub_kernels[i]
        else:
            return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]

    def get_mutated_args_sub_kernels(self) -> List[str]:
        mutated_args = set()
        for sub_kernel in self.sub_kernels:
            for mutation in sub_kernel.mutations:
                if mutation in sub_kernel.args.input_buffers:
                    mutated_args.add(sub_kernel.args.input_buffers[mutation])
                if (
                    mutation in sub_kernel.args.inplace_buffers
                    and mutation not in V.graph.removed_buffers
                    and mutation not in sub_kernel.removed_buffers
                ):
                    mutated_args.add(
                        sub_kernel.args.inplace_buffers[mutation].inner_name
                    )
                if mutation in sub_kernel.args.output_buffers:
                    mutated_args.add(sub_kernel.args.output_buffers[mutation])
        return sorted(mutated_args)

    def select_dispatch_strategy(self) -> None:
        if self.dispatch_class is not None:
            return
        # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch.
        # Kernels that do not mix sizes on the y dim could technically use round robin as well.
        if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list):
            # a str in x_numels_list means a dynamic shape
            self.dispatch_class = ComboKernel.SequentialDispatch
            return
        # A negative x_numels_list element means the kernel is not tunable,
        # i.e., no_x_dim = True
        x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list]
        total = max(x_numels_list) * len(x_numels_list)
        needed = sum(x_numels_list)
        if needed / total > BLOCK_UTILIZATION:
            # Introduced overhead (masked blocks) is less than 20%
            self.dispatch_class = ComboKernel.RoundRobinDispatch
        else:
            self.dispatch_class = ComboKernel.SequentialDispatch
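
    # Worked example (illustrative numbers) of the utilization check above: for
    # x_numels_list = [1024, 768, 896], total = 1024 * 3 = 3072 and
    # needed = 2688, so needed / total = 0.875 > BLOCK_UTILIZATION (0.8) and
    # round-robin dispatch is chosen; with [1024, 128, 128] the ratio would be
    # 1280 / 3072 ~= 0.42, so sequential dispatch is used instead.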

    def jit_line(
        self,
        heuristics: str,
        size_hints: List[int],
        selected_kernel: TritonKernel,
        pointwise_with_reduce: bool = False,
        signature: Optional[List[Any]] = None,
    ) -> str:
        can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
        size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
        if signature is None:
            _, _, signature, _ = self.args.python_argdefs()
        for i, sub in enumerate(self.sub_kernels):
            self.min_x_blocks_sub_kernel(sub, i)
        self.select_dispatch_strategy()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=size_dtype),
            "device": DeviceProperties.create(
                V.graph.scheduler.get_current_device_or_throw()
            ),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        mutated_args = self.get_mutated_args_sub_kernels()
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "mutated_arg_names": mutated_args,
            **TritonKernel.inductor_meta_common(),
        }

        sub_kernel = selected_kernel
        if heuristics == "foreach":
            heuristics_line = f"""
                @triton_heuristics.foreach(
                    num_warps={self.num_warps},
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r},
                )
                @triton.jit
            """
        elif sub_kernel.inside_reduction:
            reduction_hint = sub_kernel.reduction_hint
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """
        else:
            tile_hint = ""
            if len(size_hints) == 2:
                tile_hint = "tile_hint=TileHint.SQUARE,"
            else:
                tile_hint = "tile_hint=TileHint.DEFAULT,"
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r}, {tile_hint}
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """

        return heuristics_line

    def codegen_blocks(self, code: IndentedBuffer) -> None:
        for block in self.block_args:
            assert block in [
                "XBLOCK",
                "YBLOCK",
                "RBLOCK",
            ], f"{block} is not supported without autotuning"
        if "YBLOCK" in self.block_args:
            code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
            code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
        else:
            code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
        if "RBLOCK" in self.block_args:
            code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")
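
    # When autotuning is disabled, codegen_blocks pins the block sizes at the top
    # of the generated kernel. For a 1D combo kernel that also has a reduction
    # block arg, the emitted lines would (with the default sizes above) look like:
    #
    #     XBLOCK: tl.constexpr = 1024
    #     RBLOCK: tl.constexpr = 256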

    def add_blockd_to_args(self, argdefs: List[str]) -> List[str]:
        block_args = {}
        block_names = {}
        for num, sub_kernel in enumerate(self.sub_kernels):
            # TODO: we assume all sub_kernels have the same block size
            for tree in sub_kernel.range_trees:
                if tree.prefix == "r" and (
                    not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
                ):
                    continue
                if tree.prefix == "x" and sub_kernel.no_x_dim:
                    continue
                block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix
                block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix
        if self.enable_autotune:
            argdefs.extend(block_args)
        self.block_args = list(block_names.keys())
        return argdefs

    def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]:
        for num, sub_kernel in enumerate(self.sub_kernels):
            for tree in sub_kernel.active_range_trees():
                if not isinstance(tree.numel, (Integer, int)):
                    # only if it is a dynamic shape
                    sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel)
                    signature.append(sizearg)
                    argdefs.append(f"{tree.prefix}numel_{num}")
                    self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}")
        return argdefs

    def add_numel_to_call_args_and_grid(
        self, name: str, call_args: List[Any], arg_types: List[Any], grid: List[Any]
    ) -> None:
        for num, sub_kernel in enumerate(self.sub_kernels):
            for i, tree in enumerate(sub_kernel.range_trees):
                numel_name = f"{tree.prefix}numel_{num}"
                if numel_name not in self.dynamic_shape_args:
                    continue
                if isinstance(tree.numel, (Integer, Symbol)):
                    expr = tree.numel
                else:
                    expr = V.graph.wrapper_code.generate_numel_expr(
                        name, tree, suffix=str(num)
                    )
                if tree.prefix != "r":
                    assert isinstance(
                        grid[i][num], str
                    ), f"Grid {grid[i][num]} should be a dynamic shape."
                    numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
                    assert (
                        grid[i][num] == numel_sign + numel_name
                    ), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
                    grid[i][num] = -expr if numel_sign == "-" else expr

                if tree.prefix != "r" or sub_kernel.inside_reduction:
                    call_args.append(expr)
                    arg_types.append(type(expr))

    def add_numel_to_call_args_and_grid_benchmark(
        self, extra_args: List[Any], grid: Union[List[Any], Tuple[Any, ...]]
    ) -> None:
        for num, sub_kernel in enumerate(self.sub_kernels):
            for i, tree in enumerate(sub_kernel.range_trees):
                numel_name = f"{tree.prefix}numel_{num}"
                if numel_name not in self.dynamic_shape_args:
                    continue
                expr = V.graph.sizevars.size_hint(tree.numel)
                if tree.prefix != "r":
                    assert isinstance(
                        grid[i][num], str
                    ), f"Grid {grid[i][num]} should be a dynamic shape."
                    numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
                    assert (
                        grid[i][num] == numel_sign + numel_name
                    ), f"grid mismatch: {grid[i][num]} vs {numel_name}"
                    grid[i][num] = -expr if numel_sign == "-" else expr
                if tree.prefix != "r" or sub_kernel.inside_reduction:
                    extra_args.append(expr)

    def codegen_kernel(self, name: Optional[str] = None) -> str:
        # TODO: is it correct to use the first sub kernel's heuristics?
        heuristics_list, size_hints_list = [], []
        for subkernel in self.sub_kernels:
            h, s = self.select_heuristics(subkernel)
            heuristics_list.append(h)
            size_hints_list.append(s)
        heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
            heuristics_list, size_hints_list
        )
        pointwise_with_reduction, heuristics = (
            (True, "pointwise")
            if heuristics == "pointwise_with_reduction"
            else (False, heuristics)
        )
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        if config.benchmark_combo_kernel:
            code.splice(self.imports_for_benchmark_kernel())

        argdefs, _, signature, _ = self.args.python_argdefs()
        argdefs = self.add_numel_to_args(argdefs, signature)
        argdefs = self.add_blockd_to_args(argdefs)
        code.splice(
            self.jit_line(
                heuristics,
                size_hints,
                selected_kernel,
                pointwise_with_reduce=pointwise_with_reduction,
                signature=signature,
            )
        )
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("pid = tl.program_id(0)")
            if not self.enable_autotune:
                self.codegen_blocks(code)

            for num, sub_kernel in enumerate(self.sub_kernels):
                assert self.dispatch_class is not None
                self.dispatch_class.codegen_pid_range(self, num, code)
                with code.indent():
                    uniquify = self.codegen_static_numels_sub_kernel(
                        code, sub_kernel, num
                    )
                    sub_kernel.codegen_body()
                    uniquified_body = self.uniquify_block_sizes(
                        sub_kernel.body, num, uniquify
                    )
                    code.splice(uniquified_body)

            code.splice("else:")
            with code.indent():
                code.splice("pass")

        if config.benchmark_combo_kernel:
            code.splice(self.codegen_kernel_benchmark(num_gb=0))

        return code.getvalue()

    def codegen_kernel_benchmark(
        self, num_gb: float, grid: Optional[List[Any]] = None
    ) -> IndentedBuffer:
        result = IndentedBuffer()
        argdefs, call_args, signature, _ = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            name_cnt = itertools.count()
            var_names = []
            for arg_name, arg_sig in zip(call_args, signature):
                var_name = f"arg_{next(name_cnt)}"
                buf = V.graph.try_get_buffer(arg_name)
                if buf:
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offsets will have the same benchmark harness.
                    # We can dedup kernel definitions in this case.
                    if "seed_offset" in arg_sig.name:
                        symval_hint = 0
                    result.writeline(f"{var_name} = {symval_hint}")
                elif isinstance(arg_sig, WorkspaceArg):
                    device = V.graph.scheduler.get_current_device_or_throw()
                    nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes)
                    result.writeline(
                        f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)"
                    )
                else:
                    raise KeyError(
                        f"Can't find the buffer or const tensor for {arg_name}"
                    )
                var_names.append(var_name)
            result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
        if grid is None:
            assert self.dispatch_class is not None
            dynamic_shape = self.dynamic_shape_args != []
            grid_tuple = self.dispatch_class.grid(
                self.grids, self.x_numels_list, dynamic_shape
            )
            extra_args_str = ""
            extra_args: List[Any] = []
            if dynamic_shape:
                self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple)
                # convert nested list to list of str
                grid_tuple = tuple(
                    "[" + ", ".join(pexpr(item) for item in e) + ",]"
                    for e in grid_tuple
                )
                extra_args_str = ", ".join(map(str, extra_args)) + ", "
                min_blocks = None
            else:
                min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels)
            grid_str = ", ".join(pexpr(item) for item in grid_tuple)
            grid_extra_kwargs = (
                f"num_kernels={len(self.sub_kernels)}, "
                f"min_blocks={min_blocks}, "
                f"is_sequential={self.dispatch_class is self.SequentialDispatch}"
            )
            grid_str = f"{grid_str}, {grid_extra_kwargs}"
            grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})"
        else:
            grid_arg = f"grid={grid}"
        index = V.graph.scheduler.get_current_device_or_throw().index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_raw_stream({index})")
                result.writeline(
                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                result.writeline(
                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
                )

        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline(
                "from torch._inductor.runtime.benchmarking import benchmarker"
            )
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)"
            )
            result.writeline(f"num_gb = {num_gb}")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
            )

        return result

    def imports_for_benchmark_kernel(self) -> str:
        return textwrap.dedent(
            """
            from torch._dynamo.testing import rand_strided
            {}
            import torch
            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
            """.format(
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
            )
        )

    def uniquify_block_sizes(
        self, code: IndentedBuffer, num_kernel: int, uniquify: List[str]
    ) -> IndentedBuffer:
        if not uniquify:
            return code
        modified = IndentedBuffer(initial_indent=code._indent)
        for line in code._lines:
            if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]):
                modified_line = line
                for block in blocks:
                    modified_line = modified_line.replace(
                        block, f"{block}_{num_kernel}"
                    )
                modified.writeline(modified_line)
            elif isinstance(line, DeferredLine) and (
                blocks := [e for e in uniquify if e in line.line]
            ):
                modified_line = line.line
                for block in blocks:
                    modified_line = modified_line.replace(
                        block, f"{block}_{num_kernel}"
                    )
                new_line = DeferredLine(line.name, modified_line)
                modified.writeline(new_line)
            else:
                modified.writeline(line)
        return modified

    def call_kernel(self, code: IndentedBuffer, name: str) -> None:
        _, call_args, _, arg_types = self.args.python_argdefs()

        wrapper = V.graph.wrapper_code
        assert self.dispatch_class is not None
        dynamic_shape = self.dynamic_shape_args != []
        grid = list(
            self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape)
        )
        num_kernels = len(self.sub_kernels)
        min_blocks = (
            max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None
        )
        is_sequential = self.dispatch_class is self.SequentialDispatch
        if dynamic_shape:
            self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
            # convert nested list to list of str
            # grid = tuple("[" + ", ".join(pexpr(item) for item in e) + ",]" for e in grid)
        if not self.enable_autotune and not dynamic_shape:
            launch_grid = self.grid_no_autotune(
                grid, num_kernels, cast(int, min_blocks), is_sequential
            )
            V.graph.wrapper_code.generate_kernel_call(
                name,
                call_args,
                grid=launch_grid,
                arg_types=arg_types,
                grid_fn="",
            )
            return
        # autotuning is enabled
        grid = wrapper.generate_default_grid(
            name,
            list(grid),
            grid_callable=grid_combo_kernels,
            num_kernels=num_kernels,
            min_blocks=min_blocks,
            is_sequential=is_sequential,
            default_meta=None if self.enable_autotune else self.get_default_meta(),
        )
        wrapper.generate_kernel_call(
            name,
            call_args,
            grid,
            V.graph.scheduler.get_current_device_or_throw().index,
            cuda=True,
            triton=True,
            arg_types=arg_types,
            grid_fn="grid_combo_kernels",
            grid_extra_kwargs=(
                f"num_kernels={num_kernels}, "
                f"min_blocks={min_blocks}, "
                f"is_sequential={is_sequential}, "
                f"default_meta={None if self.enable_autotune else self.get_default_meta()}"
            ),
        )

    def grid_no_autotune(
        self,
        grid: Union[Tuple[Any], List[Any]],
        num_kernels: int,
        min_blocks: int,
        is_sequential: bool,
    ) -> List[int]:
        meta = self.get_default_meta()
        grid_func = grid_combo_kernels(
            *grid,
            num_kernels=num_kernels,
            min_blocks=min_blocks,
            is_sequential=is_sequential,
        )
        return grid_func(meta)

    def get_default_meta(self) -> Dict[str, int]:
        if "YBLOCK" in self.block_args:
            meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d}
        else:
            meta = {"XBLOCK": self.block_size_1d}
        return meta