1# mypy: disallow-untyped-defs 2from __future__ import annotations 3 4import collections 5import dataclasses 6import functools 7import itertools 8import logging 9import math 10import operator 11import os 12import pprint 13import textwrap 14import traceback 15import typing 16from typing import ( 17 Any, 18 Callable, 19 Counter, 20 DefaultDict, 21 Dict, 22 Generic, 23 List, 24 Optional, 25 Sequence, 26 Set, 27 Tuple, 28 TypeVar, 29 Union, 30) 31 32import sympy 33 34import torch 35import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools 36from torch._dynamo.utils import counters, dynamo_timed 37from torch._inductor.metrics import get_metric_table, is_metric_table_enabled 38from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols 39from torch.utils._ordered_set import OrderedSet 40from torch.utils._sympy.symbol import free_symbol_is_type, SymT 41from torch.utils._triton import has_triton 42 43from . import comms, config, dependencies, ir, metrics 44from .codecache import write_text 45from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel 46from .comm_analysis import estimate_nccl_collective_runtime 47from .dependencies import Dep, MemoryDep, StarDep, WeakDep 48from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout 49from .loop_body import LoopBody 50from .runtime.runtime_utils import green_text, red_text 51from .sizevars import SimplifyIndexing 52from .utils import ( 53 cache_on_self, 54 cmp, 55 device_need_guard, 56 get_device_tflops, 57 get_dtype_size, 58 get_gpu_dram_gbps, 59 IndentedBuffer, 60 is_collective, 61 is_gpu, 62 is_wait, 63 sympy_product, 64) 65from .virtualized import V 66 67 68log = logging.getLogger(__name__) 69fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") 70loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") 71 72 73@dataclasses.dataclass 74class SchedulerBuffer: 75 scheduler: Scheduler 76 node: ir.Buffer 77 defining_op: BaseSchedulerNode 78 users: List[NodeUser] = dataclasses.field(default_factory=list) 79 80 def __hash__(self) -> int: 81 return hash(self.node.name) 82 83 def debug_str(self) -> str: 84 result = IndentedBuffer() 85 name = self.get_name() 86 result.writeline(f"{name}: {type(self.node).__name__}") 87 result.writeline(f"{name}.layout = {self.node.layout}") 88 if self.get_aliases(): 89 result.writeline(f"{name}.aliases = {pformat(self.get_aliases())}") 90 if self.get_mutations(): 91 result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}") 92 93 if len(self.users) <= 1: 94 result.writeline(f"{name}.users = {self.users}") 95 else: 96 result.writeline(f"{name}.users = [") 97 with result.indent(1): 98 for user in self.users: 99 result.writeline(f"{user},") 100 result.writeline("]") 101 return result.getrawvalue() 102 103 def get_name(self) -> str: 104 return self.node.get_name() 105 106 def allocate(self) -> None: 107 assert self.node is not None 108 if not self.node.should_allocate(): 109 return 110 111 if self.node.get_inputs_that_alias_output() or self.node.get_mutation_names(): 112 V.graph.wrapper_code.codegen_allocation(self.node) 113 return 114 115 # hacky check for if V.kernel is a real kernel or NullHandler 116 if ( 117 hasattr(V.kernel, "args") 118 and self.get_name() in V.kernel.inplace_update_buffers 119 ): 120 V.graph.wrapper_code.codegen_inplace_reuse( 121 self.scheduler.name_to_buf[ 122 V.kernel.inplace_update_buffers[self.get_name()] 123 ].node, 124 self.node, 125 ) 126 else: 127 
V.graph.wrapper_code.codegen_allocation(self.node) 128 129 def can_free(self) -> bool: 130 # There's no real allocated buffer, no need to free it 131 assert self.node is not None 132 if isinstance(self.node.layout, ir.NoneLayout): 133 return False 134 for use in self.users: 135 if isinstance(use.node, OutputNode): 136 return False 137 return True 138 139 def set_users(self, users: List[NodeUser]) -> None: 140 # deduplicate 141 result: Dict[int, NodeUser] = {} 142 for use in users: 143 if id(use.node) in result: 144 result[id(use.node)] = use.merge(result[id(use.node)]) 145 else: 146 result[id(use.node)] = use 147 self.users = list(result.values()) 148 149 def get_aliases(self) -> Sequence[str]: 150 assert self.node is not None 151 return self.node.get_inputs_that_alias_output() 152 153 def get_mutations(self) -> List[str]: 154 assert self.node is not None 155 return self.node.get_mutation_names() 156 157 158class BaseSchedulerNode: 159 group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] 160 read_writes: dependencies.ReadWrites 161 unmet_dependencies: OrderedSet[Dep] 162 # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. 163 # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node 164 # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. 165 # For non-"grouped" nodes (i.e. regular SchedulerNode), 166 # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. 167 min_order: int 168 max_order: int 169 170 def __init__(self, scheduler: Scheduler) -> None: 171 self.scheduler: Scheduler = scheduler 172 173 def _init_from_node(self, node: ir.Operation) -> None: 174 self.node: Optional[ir.Operation] = node 175 self.ancestors: OrderedSet[str] = OrderedSet() 176 self.last_usage: OrderedSet[ 177 str 178 ] = OrderedSet() # buffers that won't be used after this kernel 179 self.written = False 180 self.outputs: List[SchedulerBuffer] = [ 181 SchedulerBuffer( 182 scheduler=self.scheduler, 183 node=output, 184 defining_op=self, 185 ) 186 for output in node.get_outputs() 187 ] 188 self.outputs_by_name: Dict[str, SchedulerBuffer] = { 189 buf.get_name(): buf for buf in self.outputs 190 } 191 192 def __repr__(self) -> str: 193 return f"{type(self).__name__}(name={self.get_name()!r})" 194 195 def debug_str(self) -> str: 196 """Longer form printout for trace logs""" 197 name = self.get_name() 198 buf = IndentedBuffer() 199 buf.splice( 200 f"""\ 201{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__}) 202{name}.writes = {pformat(self.read_writes.writes)} 203{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} 204{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} 205{name}.outputs = [ 206 """ 207 ) 208 with buf.indent(): 209 for out in self.get_outputs(): 210 buf.splice(out.debug_str()) 211 buf.writeline("]") 212 213 try: 214 buf.splice(self.debug_str_extra()) 215 except Exception: 216 log.warning("Ignoring error in debug_str()", exc_info=True) 217 218 return buf.getrawvalue().rstrip() 219 220 def debug_str_extra(self) -> str: 221 return "" 222 223 def debug_str_short(self) -> str: 224 maybe_data = getattr(self.node, "data", None) 225 data_str = "" 226 if isinstance(maybe_data, torch._inductor.ir.Pointwise): 227 data_str = ", " + maybe_data.str_helper( 228 [maybe_data.get_size()], shorten=False, multiline=False 229 ) 230 elif isinstance(maybe_data, torch._inductor.ir.Reduction): 231 data_str = 
", " + maybe_data.str_helper( 232 [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()], 233 shorten=False, 234 multiline=False, 235 ) 236 return f"{self}{data_str}" 237 238 def log_details(self) -> None: 239 log.info( 240 "%s: unmet_dependencies = %s, writes = %s", 241 self, 242 self.unmet_dependencies, 243 self.read_writes.writes, 244 ) 245 246 def reorder_loops_by_dep_pair( 247 self, self_dep: MemoryDep, other_dep: MemoryDep 248 ) -> None: 249 return 250 251 def update_mutated_names(self, renames: Dict[str, str]) -> None: 252 self.set_read_writes(self.read_writes.rename(renames)) 253 254 def add_fake_dep(self, dep: Dep) -> None: 255 self.set_read_writes(self.read_writes.with_read(dep)) 256 257 def has_aliasing_or_mutation(self) -> bool: 258 return any( 259 buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs() 260 ) 261 262 def set_read_writes(self, rw: dependencies.ReadWrites) -> None: 263 self.read_writes = rw 264 self.unmet_dependencies = self.read_writes.reads 265 self.prune_deps() 266 267 def set_last_usage( 268 self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] 269 ) -> None: 270 used_buffers = self.used_or_aliased_buffer_names() 271 used_buffers = OrderedSet([mutation_real_name.get(k, k) for k in used_buffers]) 272 self.last_usage = used_buffers - future_used_buffers 273 274 def mark_run(self) -> None: 275 for buf in self.outputs: 276 buf.allocate() 277 278 def used_buffer_names(self) -> OrderedSet[str]: 279 return OrderedSet( 280 dep.name 281 for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) 282 ) 283 284 def used_or_aliased_buffer_names(self) -> OrderedSet[str]: 285 used_names: OrderedSet[str] = OrderedSet() 286 287 deps = [ 288 dep.name 289 for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) 290 ] 291 while len(deps) > 0: 292 dep = deps.pop() 293 used_names.add(dep) 294 if V.graph.name_to_buffer.get(dep): 295 for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output(): 296 if alias not in used_names: 297 deps.append(alias) 298 return used_names 299 300 def prune_deps(self) -> None: 301 self.unmet_dependencies = OrderedSet( 302 dep 303 for dep in self.unmet_dependencies 304 if dep.name not in self.scheduler.available_buffer_names 305 ) 306 307 def prune_weak_deps(self) -> None: 308 # Prune weak dependencies on operations that have been removed 309 def should_prune(dep: Dep) -> bool: 310 if not isinstance(dep, WeakDep): 311 return False 312 op = self.scheduler.name_to_buf[dep.name].defining_op 313 return op.get_name() in V.graph.removed_operations 314 315 to_remove = OrderedSet( 316 dep for dep in self.read_writes.reads if should_prune(dep) 317 ) 318 self.set_read_writes(self.read_writes.remove_reads(to_remove)) 319 320 def prune_redundant_deps( 321 self, name_to_fused_node: Dict[str, BaseSchedulerNode] 322 ) -> None: 323 _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) 324 325 def get_name(self) -> str: 326 assert self.node is not None 327 return self.node.get_operation_name() 328 329 def get_first_name(self) -> str: 330 return self.get_name() 331 332 def get_operation_names(self) -> OrderedSet[str]: 333 return OrderedSet(node.get_name() for node in self.get_nodes()) 334 335 def get_buffer_names(self) -> OrderedSet[str]: 336 return OrderedSet(out.get_name() for out in self.outputs) 337 338 def get_nodes(self) -> Sequence[BaseSchedulerNode]: 339 return [self] 340 341 def get_outputs(self) -> Sequence[SchedulerBuffer]: 342 return 
self.outputs 343 344 def get_output(self, buf_name: str) -> SchedulerBuffer: 345 return self.outputs_by_name[buf_name] 346 347 def get_device(self) -> torch.device: 348 assert self.node is not None 349 return self.node.get_device() 350 351 def is_reduction(self) -> bool: 352 return False 353 354 def is_split_scan(self) -> bool: 355 return False 356 357 def is_template(self) -> bool: 358 return False 359 360 def is_extern(self) -> bool: 361 return False 362 363 def is_foreach(self) -> bool: 364 return False 365 366 def can_inplace(self, read_dep: dependencies.Dep) -> bool: 367 return False 368 369 def has_side_effects(self) -> bool: 370 return False 371 372 def decide_inplace_update(self) -> None: 373 """ 374 Decide if there should be inplace updates for the node 375 and record the decision in the active kernel. 376 """ 377 from .codegen.wrapper import buffer_reuse_key 378 379 if not ( 380 isinstance(self, (SchedulerNode,)) 381 and config.inplace_buffers 382 and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS) 383 and ( 384 not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) 385 or getattr(V.kernel, "mutations", None) is not None 386 ) 387 # hacky check for if V.kernel is a real kernel or NullHandler 388 and hasattr(V.kernel, "args") 389 ): 390 return 391 392 ordered_reads = sorted(self.read_writes.reads, key=lambda x: x.name) 393 394 for buf in self.get_outputs(): 395 buf_node = buf.node 396 assert buf_node is not None 397 if ( 398 not buf_node.should_allocate() 399 or buf_node.get_inputs_that_alias_output() 400 or buf_node.get_mutation_names() 401 or buf.get_name() in V.graph.removed_buffers 402 ): 403 continue 404 405 for read in ordered_reads: 406 input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get( 407 read.name 408 ) 409 if ( 410 input_buf 411 and V.graph.wrapper_code.can_reuse(input_buf, self) 412 and not isinstance(input_buf.defining_op, NopKernelSchedulerNode) 413 ): 414 assert input_buf.users is not None 415 remaining_uses = [ 416 x 417 for x in input_buf.users 418 if x.node.get_name() not in self.scheduler.completed_operations 419 ] 420 if ( 421 len(remaining_uses) == 1 422 and remaining_uses[0].can_inplace 423 and remaining_uses[0].node is self 424 and input_buf.node is not None 425 and not isinstance( 426 input_buf.node.get_layout(), 427 ( 428 ir.MultiOutputLayout, 429 ir.MutationLayoutSHOULDREMOVE, 430 ), 431 ) 432 and not ( 433 isinstance( 434 input_buf.defining_op.node, 435 (ir.FallbackKernel, ir.MultiOutput), 436 ) 437 and len(input_buf.node.get_inputs_that_alias_output()) > 0 438 ) 439 and buffer_reuse_key(input_buf.node) 440 == buffer_reuse_key(buf.node) 441 ): 442 # if there isn't a triton kernel, then we don't need to call triton-specific things. 443 # but TODO this might be a convenient place to signal to the Collective kernels to inplace 444 # (and, can we make "kernel" less generic of a name?) 
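                        # Record the reuse decision: `buf` will be written in
                        # place over `input_buf`'s storage.  The mapping stored
                        # in inplace_update_buffers below is what lets
                        # SchedulerBuffer.allocate() emit an in-place reuse
                        # instead of a fresh allocation.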
445 V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name()) 446 # mutations not tracked in cpp kernels 447 if isinstance( 448 V.kernel, torch._inductor.codegen.simd.SIMDKernel 449 ): 450 V.kernel.mutations.add(input_buf.get_name()) 451 V.kernel.mutations.add(buf.get_name()) 452 453 # update last usage of reused node 454 self.last_usage.discard(input_buf.get_name()) 455 456 V.kernel.inplace_update_buffers[ 457 buf.get_name() 458 ] = input_buf.get_name() 459 break 460 461 def codegen_originating_info( 462 self, buffer: IndentedBuffer, only_once: bool = True 463 ) -> None: 464 if not config.comment_origin: 465 return 466 467 if only_once and self.written: 468 return 469 assert self.node is not None 470 origins = self.node.get_origins() 471 out_lines = [] 472 473 for o in origins: 474 if o.op == "output": 475 # These are boring and samey 476 continue 477 478 out_lines.append("") 479 # TODO(voz): Should the pragma be constant somewhere? 480 out_lines.append("#pragma CMT ORIGIN:") 481 op_info_str = f"#pragma CMT {o.op} {o.target}" 482 if "seq_nr" in o.meta: 483 op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" 484 out_lines.append(op_info_str) 485 if "stack_trace" in o.meta: 486 stack_trace = f"{o.meta['stack_trace']}" 487 stack_trace_last_line = stack_trace.split("|")[-1] 488 out_lines.append( 489 "#pragma CMT " 490 + stack_trace_last_line.replace("{", "{{") 491 .replace("}", "}}") 492 .replace("\n", "\\") 493 ) 494 out_lines.append("#pragma CMT END ORIGIN") 495 out_lines.append("") 496 497 if len(out_lines) == 0: 498 return 499 500 # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does 501 # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 502 buffer.writelines(out_lines) 503 self.written = True 504 505 def get_read_write_buffers_sizes(self) -> int: 506 """ 507 Counting the number of bytes accessed for a kernel is 508 surprisingly tricky. In particular, there is a differentiation 509 between 'theoretical' memory accesses and practical memory 510 accesses. For example, a layernorm kernel may actually access an 511 input 3 times, but in theory, it only needs to access its input 512 once (and may be optimized to do so through say, persistent 513 reductions) 514 515 Another example is that even though a buffer is passed in, we may 516 not access the entire buffer. This may occur if we are accessing 517 a slice of the buffer. Another tricky case is for indirect 518 indexing, where the amount of bytes accessed depends on the 519 values of the input. 520 521 What this function aims to compute is the memory accesses for 522 worst-case inputs, best-case optimization. What this means is 523 that for each buffer we compute the amount of potential accesses in two ways and take the minimum. 524 525 1. Numel in ranges multiplied by number of deps the buffer has 526 2. The buffer size 527 """ 528 if isinstance(self, NopKernelSchedulerNode): 529 return 0 530 if isinstance(self, ExternKernelSchedulerNode) and isinstance( 531 self.node, MultiOutput 532 ): 533 # todo: Calculate this - it's kinda annoying. 
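            # (A MultiOutput only selects one result of a multi-output parent
            # buffer; see get_buf_bytes below, which handles MultiOutputLayout
            # parents by summing the sizes of their MultiOutput outputs.)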
534 return 0 535 536 def try_size_hint(s: sympy.Expr) -> int: 537 return V.graph.sizevars.size_hint(s, fallback=0) 538 539 if isinstance(self, SchedulerNode): 540 node_numel = try_size_hint( 541 sympy_product(self.get_ranges()[0]) 542 * sympy_product(self.get_ranges()[1]), 543 ) 544 else: 545 node_numel = int(1e9) 546 buf_accesses = collections.defaultdict(list) 547 for dep in self.read_writes.reads | self.read_writes.writes: 548 buf_accesses[dep.name].append(dep) 549 550 reads = OrderedSet(dep.name for dep in self.read_writes.reads) 551 writes = OrderedSet(dep.name for dep in self.read_writes.writes) 552 553 def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: 554 users = self.scheduler.name_to_buf[buf].users 555 buf_uses = OrderedSet(user.node for user in users) 556 return len(buf_uses - OrderedSet(snodes)) > 0 557 558 if isinstance(self, FusedSchedulerNode): 559 removed_buffers = OrderedSet( 560 dep for dep in writes if not is_materialized(dep, self.snodes) 561 ) 562 writes = writes - removed_buffers 563 reads = reads - removed_buffers 564 node_bytes = 0 565 566 for buf_name in reads | writes: 567 buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) 568 buf: Union[ir.Buffer, ir.TensorBox] 569 if buf_name in V.graph.name_to_buffer: 570 buf = V.graph.name_to_buffer[buf_name] 571 elif buf_name in V.graph.graph_inputs: 572 buf = V.graph.graph_inputs[buf_name] 573 else: 574 continue 575 576 def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: 577 if not buf: 578 return 0 579 # Kind of a lazy way to get the MultiOutput nodes corresponding to 580 # a MultiOutputLayout 581 if isinstance(buf.layout, MultiOutputLayout): 582 users = self.scheduler.name_to_buf[buf.get_name()].users 583 tot = 0 584 for user in users: 585 assert isinstance(user.node, BaseSchedulerNode) 586 if isinstance(user.node.node, MultiOutput): 587 for sched_buf in user.node.get_outputs(): 588 tot += get_buf_bytes(sched_buf.node) 589 else: 590 # Buf is a MultiOutputLayout but not all of its 591 # users are MultiOutputs... 592 # TODO: Figure out what's going on 593 return 0 594 return tot 595 elif isinstance(buf.layout, ir.NoneLayout): 596 return sum( 597 get_buf_bytes(V.graph.get_buffer(mut_name)) 598 for mut_name in buf.get_mutation_names() 599 ) 600 else: 601 buf_elems = try_size_hint(sympy_product(buf.get_size())) 602 return get_dtype_size(buf.get_dtype()) * min( 603 buf_accessed_elems, buf_elems 604 ) 605 606 node_bytes += get_buf_bytes(buf) 607 608 return node_bytes 609 610 def get_estimated_runtime(self) -> float: 611 """ 612 Returns estimated op runtime in nanoseconds (ns) 613 """ 614 buf = self.get_nodes()[0].get_outputs()[0] 615 layout = buf.node.get_layout() 616 dtype = buf.node.get_dtype() 617 618 if layout.device is not None and not is_gpu(layout.device.type): 619 # default to no reordering based on runtime 620 return 0 621 622 # Collective kernels 623 if is_collective(self.node): 624 assert isinstance(self.node, ir.IRNode) 625 try: 626 return estimate_nccl_collective_runtime(self.node) 627 except ValueError as e: 628 # We don't know how to estimate runtime for this collective, 629 # falling back to 0 630 log.info(e) 631 return 0 632 633 elif is_wait(self.node): 634 # ir.Wait is only used for collective ops. 635 # The time needed for the collective op is already estimated and considered 636 # when we are processing the collective op IR node, so ir.Wait takes 0 time 637 # since it doesn't take extra time to get the result after the collective is completed. 
638 return 0 639 640 try: 641 gpu_memory_bandwidth = get_gpu_dram_gbps() 642 gpu_flops = get_device_tflops(dtype) * 10**12 643 except Exception: 644 return 0 645 646 if isinstance(self, ExternKernelSchedulerNode): 647 assert isinstance(self.node, ir.ExternKernel), f"{type(self.node)=}" 648 op = kernel_name_to_op.get( 649 getattr(self.node, "python_kernel_name", ""), None 650 ) 651 652 # if there is a resolved op, dry-run using fake mode and record flop count 653 if op is not None: 654 from torch._subclasses.fake_tensor import FakeTensorMode 655 from torch.utils.flop_counter import FlopCounterMode 656 657 if any( 658 len(free_unbacked_symbols(n.get_numel())) > 0 659 for n in self.node.inputs 660 ): 661 # Tensor has unbacked symints, we don't know how to estimate 662 # runtime for that today 663 return 0 664 665 with FakeTensorMode() as fake_mode, FlopCounterMode( 666 display=False 667 ) as flop_counter_mode, V.set_current_node( 668 self.node.fx_node 669 ), V.set_fake_mode( 670 fake_mode 671 ): 672 from .ir import ir_node_to_tensor 673 674 fake_inputs = [ 675 ir_node_to_tensor(input, guard_shape=False) 676 for input in self.node.inputs 677 ] 678 cls = self.node.__class__ 679 cls.process_kernel(op, *fake_inputs, **self.node.kwargs) 680 681 # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship 682 factor = 1.0 683 counted_flops = flop_counter_mode.get_total_flops() 684 counted_bytes = self.get_read_write_buffers_sizes() 685 compute_time = (factor * counted_flops / gpu_flops) * 1e9 686 transfer_time = counted_bytes / gpu_memory_bandwidth 687 688 # Return estimated runtime in nanoseconds 689 return max(compute_time, transfer_time) 690 691 elif isinstance(self, FusedSchedulerNode) or isinstance( 692 self.node, ComputedBuffer 693 ): 694 # Return estimated runtime in nanoseconds (bytes / gbps) 695 return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth 696 697 return 0 698 699 def get_template_node(self) -> Optional[ir.TemplateBuffer]: 700 return None 701 702 703class WhyNoFuse: 704 # TODO when we drop support for Python < 3.10, we can use 705 # @dataclass(slots=True) instead of manually specifying __slots__. 706 __slots__ = ["node1", "node2", "reason", "args"] 707 reason: str 708 args: Tuple[Any, ...] 
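
    # Typical usage elsewhere in this file (see e.g. ForeachKernelSchedulerNode.can_fuse):
    #
    #   why = WhyNoFuse(producer, consumer)
    #   why("foreach do not have same length")
    #
    # Calling the instance records the reason/args and logs
    # "cannot fuse <node1> with <node2>: <reason>" to the "fusion" artifact logger.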
709 710 def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: 711 self.node1 = node1 712 self.node2 = node2 713 714 def __call__(self, reason: str, *args: Any) -> None: 715 self.reason = reason 716 self.args = args 717 fusion_log.debug(self) 718 719 def __str__(self) -> str: 720 return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + ( 721 self.reason % self.args 722 ) 723 724 725def pformat(obj: Any) -> str: 726 if isinstance(obj, OrderedSet): 727 # pformat has trouble with sets of sympy exprs 728 obj = sorted(obj, key=str) 729 result = pprint.pformat(obj, indent=4) 730 if "\n" in result: 731 return f"\n{textwrap.indent(result, ' ' * 4)}" 732 return result 733 734 735class OutputNode: 736 def __init__(self, dep: StarDep) -> None: 737 self.unmet_dependencies = OrderedSet([dep]) 738 739 def is_reduction(self) -> bool: 740 return False 741 742 def get_inputs_that_alias_output(self) -> Sequence[str]: 743 return () 744 745 def get_name(self) -> str: 746 return "OUTPUT" 747 748 __repr__ = get_name 749 750 751def _prune_redundant_deps( 752 node: BaseSchedulerNode, 753 name_to_fused_node: Dict[str, BaseSchedulerNode], 754 name_to_buf: Dict[str, SchedulerBuffer], 755) -> None: 756 """ 757 Prunes weakdeps intended for mutation ordering 758 on an upstream fused node if after fusion there is another dependency 759 on the fused upstream node, making the weakdep redundant 760 761 In essence this enforces an ordering on fusions. As fusions occur, weakdeps will 762 be incrementally removed, enabling other fusions, ensuring they are fused in order. 763 """ 764 name_to_dep_count: Counter[str] = collections.Counter() 765 766 for dep in node.unmet_dependencies: 767 if not isinstance(dep, WeakDep): 768 op = name_to_buf[dep.name].defining_op 769 name_to_dep_count[name_to_fused_node[op.get_name()].get_name()] += 1 770 771 def should_prune(dep: Dep) -> bool: 772 if isinstance(dep, WeakDep): 773 op_name = name_to_buf[dep.name].defining_op.get_name() 774 is_redundant = name_to_dep_count[name_to_fused_node[op_name].get_name()] > 0 775 # These can occur because fused nodes always gather deps from their snodes 776 # If B has a weakdep on A 777 # B gets fused with C, then any time BC is fused, the weakdep will reappear 778 is_self_dep = name_to_fused_node[op_name] == node 779 return is_redundant or is_self_dep 780 else: 781 return False 782 783 deps_to_prune = OrderedSet( 784 dep for dep in node.unmet_dependencies if should_prune(dep) 785 ) 786 787 if deps_to_prune: 788 node.unmet_dependencies = node.unmet_dependencies - deps_to_prune 789 node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) 790 791 792# TODO(xmfan): reuse: an existing mapping for this if it exists, or formalize this into ir.py:ExternKernel 793kernel_name_to_op = { 794 "extern_kernels.convolution": torch.ops.aten.convolution, 795 "extern_kernels.mm": torch.ops.aten.mm, 796 "extern_kernels.bmm": torch.ops.aten.bmm, 797 "extern_kernels.addmm": torch.ops.aten.addmm, 798} 799 800 801class ExternKernelSchedulerNode(BaseSchedulerNode): 802 def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: 803 super().__init__(scheduler) 804 self._init_from_node(node) 805 self.set_read_writes(node.get_read_writes()) 806 807 def debug_str_extra(self) -> str: 808 return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" 809 810 def is_extern(self) -> bool: 811 return True 812 813 def has_side_effects(self) -> bool: 814 assert self.node is not None 815 return 
hasattr(self.node, "has_side_effects") and self.node.has_side_effects() 816 817 818class NopKernelSchedulerNode(BaseSchedulerNode): 819 def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: 820 super().__init__(scheduler) 821 self._init_from_node(node) 822 self.set_read_writes(node.get_read_writes()) 823 824 825class SchedulerNode(BaseSchedulerNode): 826 def __init__( 827 self, 828 scheduler: Scheduler, 829 node: Union[ir.ComputedBuffer, ir.TemplateBuffer], 830 ) -> None: 831 super().__init__(scheduler) 832 self._init_from_node(node) 833 self._compute_attrs() 834 835 def _compute_attrs( 836 self, 837 extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, 838 recompute_sizes_body_func: Optional[Callable[..., Any]] = None, 839 ) -> None: 840 assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) 841 self._sizes, self._body = self.node.simplify_and_reorder( 842 extra_indexing_constraints=extra_indexing_constraints, 843 recompute_sizes_body_func=recompute_sizes_body_func, 844 ) 845 846 group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn 847 self.group = (self.node.get_device(), group_fn(self._sizes)) 848 849 # Don't normalize since normalization will merge loops which 850 # makes it hard to decide new loop orders. 851 should_normalize = ( 852 not config.loop_ordering_after_fusion 853 or self.node.get_device().type != "cuda" 854 ) 855 856 if isinstance(self.node, ir.TemplateBuffer): 857 self.set_read_writes( 858 self.node.extract_read_writes(normalize=should_normalize) 859 ) 860 else: 861 self.set_read_writes( 862 dependencies.extract_read_writes( 863 self._body, *self._sizes, normalize=should_normalize 864 ) 865 ) 866 867 def recompute_size_and_body( 868 self, 869 extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, 870 recompute_sizes_body_func: Optional[Callable[..., Any]] = None, 871 ) -> None: 872 self._compute_attrs( 873 extra_indexing_constraints=extra_indexing_constraints, 874 recompute_sizes_body_func=recompute_sizes_body_func, 875 ) 876 877 def refresh_dependencies(self, normalize: bool) -> None: 878 # Fake dependencies are added manually. They can not be analyzed from 879 # extract_read_writes. Find them out and apply manually. 
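        # In practice these are WeakDep/StarDep entries, e.g. the mutation
        # ordering deps added through add_fake_dep(); the filter below keeps
        # exactly those.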
880 fake_deps = { 881 dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) 882 } 883 884 # don't normalize since the loop order may need to be further changed 885 # later 886 self.set_read_writes( 887 dependencies.extract_read_writes( 888 self._body, *self._sizes, normalize=normalize 889 ).with_read(fake_deps) 890 ) 891 892 def apply_new_loop_order(self, new_order: Sequence[int]) -> None: 893 self._body = self._body.reorder_iter_loops( 894 new_order, 895 ) 896 self._sizes = self._body.sizes 897 898 self.refresh_dependencies(normalize=False) 899 900 def reorder_loops_by_dep_pair( 901 self, self_dep: MemoryDep, other_dep: MemoryDep 902 ) -> None: 903 new_order = None 904 self_sizes = self._sizes[0] 905 if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: 906 new_order = self_dep.decide_loop_order_to_match(other_dep) 907 908 if new_order: 909 metrics.num_loop_reordering += 1 910 loop_ordering_log.debug( 911 "Reorder loops for %s with order %s", self.get_name(), new_order 912 ) 913 self.apply_new_loop_order(new_order) 914 else: 915 loop_ordering_log.debug( 916 "Don't reordering %s because we can not decide the suitable loop order", 917 self.get_name(), 918 ) 919 920 def debug_str_extra(self) -> str: 921 name = self.get_name() 922 lines = [ 923 f"{name}.group.device = {self.group[0]}", 924 f"{name}.group.iteration = {self.group[1]}", 925 f"{name}.sizes = {self._sizes}", 926 ] 927 for dep in self.read_writes.reads_and_writes(): 928 if not isinstance(dep, WeakDep): 929 buf_name = dep.name 930 buf = V.graph.get_buffer(buf_name) 931 lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") 932 if isinstance(self._body, LoopBody): 933 lines.append(f"class {name}_loop_body:") 934 lines.append(textwrap.indent(self._body.debug_str(), " ")) 935 936 assert self.node is not None 937 if ir.is_triton(self.node.get_device()): 938 lines.extend(debug_triton_code(self)) 939 940 return "\n".join(lines) 941 942 def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: 943 return self._sizes 944 945 def is_reduction(self) -> bool: 946 assert isinstance( 947 self.node, (ir.ComputedBuffer, ir.TemplateBuffer) 948 ), f"{type(self.node)=}" 949 return bool(self.node.get_reduction_type()) 950 951 def is_split_scan(self) -> bool: 952 assert isinstance( 953 self.node, (ir.ComputedBuffer, ir.TemplateBuffer) 954 ), f"{type(self.node)=}" 955 return isinstance(self.node, ir.ComputedBuffer) and isinstance( 956 self.node.data, ir.SplitScan 957 ) 958 959 def is_template(self) -> bool: 960 return isinstance(self.node, ir.TemplateBuffer) 961 962 def get_template_node(self) -> Optional[ir.TemplateBuffer]: 963 return self.node if isinstance(self.node, ir.TemplateBuffer) else None 964 965 def run(self, *index_vars: Sequence[sympy.Expr]) -> None: 966 self.decide_inplace_update() 967 self.mark_run() 968 self.codegen(index_vars) 969 970 def ranges_from_index_vars( 971 self, index_vars: Sequence[Sequence[sympy.Expr]] 972 ) -> Dict[sympy.Expr, sympy.Expr]: 973 sizes = self._sizes 974 assert sum(map(len, sizes)) == sum(map(len, index_vars)) 975 var_ranges = dict( 976 zip( 977 itertools.chain.from_iterable(index_vars), 978 itertools.chain.from_iterable(sizes), 979 ) 980 ) 981 return var_ranges 982 983 def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: 984 var_ranges = self.ranges_from_index_vars(index_vars) 985 try: 986 with V.set_ops_handler( 987 SimplifyIndexing(V.get_ops_handler(), var_ranges) 988 ), V.kernel.set_current_node(self): 989 self._body(*index_vars) 990 except Exception: 991 
log.fatal("Error in codegen for %s", self.node) 992 raise 993 994 @cache_on_self 995 def pointwise_read_writes(self) -> dependencies.ReadWrites: 996 """ 997 Get the memory dependencies in the non-reduction axis. 998 """ 999 sizes, reduction_sizes = self._sizes 1000 return dependencies.extract_read_writes( 1001 self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)] 1002 ) 1003 1004 def can_inplace(self, read_dep: dependencies.Dep) -> bool: 1005 if self.is_template(): 1006 return False 1007 if any(out.get_aliases() for out in self.get_outputs()): 1008 return False 1009 if len(self.read_writes.writes) == 1 and isinstance( 1010 read_dep, dependencies.MemoryDep 1011 ): 1012 write_dep = next(iter(self.read_writes.writes)) 1013 assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" 1014 return read_dep.index == write_dep.index and read_dep.size == write_dep.size 1015 return False 1016 1017 @cache_on_self 1018 def _get_atomic_add_buffers(self) -> OrderedSet[str]: 1019 buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() 1020 if isinstance(self._body, LoopBody): 1021 for node in self._body.get_nodes(): 1022 if ( 1023 node.op == "call_method" 1024 and node.target == "store" 1025 and ( 1026 ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") 1027 or (len(node.args) == 5 and node.args[4] == "atomic_add") 1028 ) 1029 ): 1030 buffers_store_as_atomic_add.add( 1031 node.kwargs["name"] 1032 if "name" in node.kwargs 1033 else (node.args[1] if len(node.args) >= 2 else "") 1034 ) 1035 return buffers_store_as_atomic_add 1036 1037 1038def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None: 1039 snodes = group_snode.snodes # type: ignore[attr-defined] 1040 group_snode.set_read_writes( 1041 dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) 1042 ) 1043 1044 group_snode.unmet_dependencies = ( 1045 OrderedSet( 1046 dep 1047 for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) 1048 if dep.name not in group_snode.get_buffer_names() 1049 ) 1050 - group_snode.read_writes.writes 1051 ) 1052 1053 1054def init_group_node( 1055 group_snode: BaseSchedulerNode, 1056 scheduler: Scheduler, 1057 snodes: List[BaseSchedulerNode], 1058) -> None: 1059 assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode)) 1060 group_snode.snodes = snodes 1061 group_snode.scheduler = scheduler 1062 group_snode.node = None 1063 group_snode.ancestors = OrderedSet.union( 1064 *[x.ancestors for x in snodes if x.ancestors is not None] 1065 ) 1066 1067 refresh_group_node_dependencies(group_snode) 1068 1069 group_snode.min_order = min(x.min_order for x in group_snode.snodes) 1070 group_snode.max_order = max(x.max_order for x in group_snode.snodes) 1071 group_snode.outputs_by_name = { 1072 buf.get_name(): buf for buf in group_snode.get_outputs() 1073 } 1074 1075 1076class FusedSchedulerNode(BaseSchedulerNode): 1077 """ 1078 This is a "fake" scheduler node that represents a group of scheduler nodes 1079 that are meant to be fused together. The way it does this is by maintaining 1080 its unmet dependencies as the union of its constituent nodes. 
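
    Illustrative example (hypothetical names): fusing op1, which writes buf1,
    with op2, which reads buf1, produces a FusedSchedulerNode named "op1_op2";
    buf1 no longer appears in its unmet_dependencies since it is produced
    inside the fused kernel.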
1081 """ 1082 1083 snodes: List[BaseSchedulerNode] 1084 1085 @classmethod 1086 def fuse( 1087 cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode 1088 ) -> FusedSchedulerNode: 1089 assert node1.scheduler is node2.scheduler 1090 assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) 1091 assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) 1092 nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) 1093 return cls(node1.scheduler, nodes) 1094 1095 def reorder_loops_by_dep_pair( 1096 self, self_dep: MemoryDep, other_dep: MemoryDep 1097 ) -> None: 1098 if self.is_template(): 1099 # We can not really reorder loops for a triton template 1100 return 1101 self_sizes = None 1102 for snode in self.snodes: 1103 assert isinstance(snode, SchedulerNode) 1104 if self_sizes is not None and self_sizes != snode._sizes[0]: 1105 loop_ordering_log.debug( 1106 "Can not reorder fused node due to different sizes" 1107 ) 1108 return 1109 self_sizes = snode._sizes[0] 1110 new_order = None 1111 1112 assert self_sizes is not None 1113 if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: 1114 new_order = self_dep.decide_loop_order_to_match(other_dep) 1115 1116 if not new_order: 1117 loop_ordering_log.debug( 1118 "Dont reordering fused node %s because we can not decide the suitable loop order", 1119 self.get_name(), 1120 ) 1121 return 1122 metrics.num_loop_reordering += 1 1123 loop_ordering_log.debug( 1124 "Reorder loops for fused node %s with order %s", self.get_name(), new_order 1125 ) 1126 for snode in self.snodes: 1127 assert isinstance(snode, SchedulerNode) 1128 snode.apply_new_loop_order(new_order) # type: ignore[arg-type] 1129 1130 refresh_group_node_dependencies(self) 1131 1132 def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: 1133 super().__init__(scheduler) 1134 init_group_node(self, scheduler, snodes) 1135 self.users: List[NodeUser] = [] 1136 self.group = max(snodes, key=lambda x: int(x.is_reduction())).group 1137 1138 @cache_on_self 1139 def get_name(self) -> str: 1140 return "_".join([x.get_name() for x in self.snodes]) 1141 1142 def get_first_name(self) -> str: 1143 return self.snodes[0].get_name() 1144 1145 @cache_on_self 1146 def get_buffer_names(self) -> OrderedSet[str]: 1147 return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) 1148 1149 def get_outputs(self) -> List[SchedulerBuffer]: 1150 result: List[SchedulerBuffer] = [] 1151 for node in self.snodes: 1152 result.extend(node.get_outputs()) 1153 return result 1154 1155 def debug_str_extra(self) -> str: 1156 lines = [ 1157 f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" 1158 for i, node in enumerate(self.snodes) 1159 ] 1160 node = self.snodes[0].node 1161 if node is not None: 1162 device = node.get_device() 1163 if ir.is_triton(device): 1164 lines.extend(debug_triton_code(self)) 1165 1166 return textwrap.indent("\n".join(lines).rstrip(), " ") 1167 1168 def debug_str_short(self) -> str: 1169 snodes_str = [node.debug_str_short() for node in self.snodes] 1170 return f"{self}, snodes: {snodes_str}" 1171 1172 def set_last_usage( 1173 self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] 1174 ) -> None: 1175 # Set self.last_usage using the global information 1176 # This will be used for inter-kernel optimisations 1177 super().set_last_usage(future_used_buffers, mutation_real_name) 1178 # Set self.last_usage on the snodes 1179 # This will be used for optimisations within the kernel 1180 future_used_buffers: OrderedSet[str] = 
OrderedSet() 1181 for node in reversed(self.snodes): 1182 node.set_last_usage(future_used_buffers, mutation_real_name) 1183 future_used_buffers.update(node.last_usage) 1184 1185 @cache_on_self 1186 def used_buffer_names(self) -> OrderedSet[str]: 1187 return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes]) 1188 1189 @cache_on_self 1190 def used_or_aliased_buffer_names(self) -> OrderedSet[str]: 1191 return OrderedSet.union( 1192 *[x.used_or_aliased_buffer_names() for x in self.snodes] 1193 ) 1194 1195 def get_nodes(self) -> Sequence[BaseSchedulerNode]: 1196 return self.snodes 1197 1198 def __repr__(self) -> str: 1199 return f"{type(self).__name__}(nodes={self.get_name()})" 1200 1201 @cache_on_self 1202 def is_reduction(self) -> bool: 1203 return any(x.is_reduction() for x in self.snodes) 1204 1205 @cache_on_self 1206 def is_split_scan(self) -> bool: 1207 return any(x.is_split_scan() for x in self.snodes) 1208 1209 @cache_on_self 1210 def is_template(self) -> bool: 1211 return any(x.is_template() for x in self.snodes) 1212 1213 @cache_on_self 1214 def get_template_node(self) -> Optional[ir.TemplateBuffer]: 1215 for node in self.snodes: 1216 if node.is_template(): 1217 return node.get_template_node() 1218 return None 1219 1220 def get_device(self) -> torch.device: 1221 return self.group[0] 1222 1223 @cache_on_self 1224 def has_aliasing_or_mutation(self) -> bool: 1225 return any(x.has_aliasing_or_mutation() for x in self.snodes) 1226 1227 # None of these need to be implemented, as a FusedSchedulerNode is just an 1228 # abstraction for scheduling purposes 1229 def update_mutated_names(self, renames: Dict[str, str]) -> None: 1230 raise NotImplementedError 1231 1232 def add_fake_dep(self, name: Dep) -> None: 1233 raise NotImplementedError 1234 1235 def can_inplace(self, read_dep: dependencies.Dep) -> bool: 1236 raise NotImplementedError 1237 1238 def debug_str(self) -> str: 1239 """Longer form printout for trace logs""" 1240 name = self.get_name() 1241 node_typestr = ",".join(type(n).__name__ for n in self.snodes) 1242 buf = IndentedBuffer() 1243 buf.splice( 1244 f"""\ 1245{name}: {type(self).__name__}({node_typestr}) 1246{name}.writes = {pformat(self.read_writes.writes)} 1247{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} 1248{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} 1249{name}.outputs = [ 1250 """ 1251 ) 1252 with buf.indent(): 1253 for out in self.get_outputs(): 1254 buf.splice(out.debug_str()) 1255 buf.writeline("]") 1256 1257 try: 1258 buf.splice(self.debug_str_extra()) 1259 except Exception: 1260 log.warning("Ignoring error in debug_str()", exc_info=True) 1261 1262 return buf.getrawvalue().rstrip() 1263 1264 1265class ForeachKernelSchedulerNode(FusedSchedulerNode): 1266 """ 1267 This is a schedular node that consists of a set of scheduler nodes that 1268 has no data dependencies among them and can be executed in parallel. 
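
    These nodes are built from the operation lists recorded in V.graph.lists
    (see Scheduler.create_foreach_nodes) and are typically code-generated as a
    single horizontally fused ("combo") kernel.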
1269 """ 1270 1271 def get_consumer_subnode_for( 1272 self, producer: BaseSchedulerNode 1273 ) -> Optional[BaseSchedulerNode]: 1274 for buf in producer.get_outputs(): 1275 if buf.get_name() in self.read_to_node: 1276 return self.read_to_node[buf.get_name()] 1277 1278 return None 1279 1280 def get_producer_subnode_for( 1281 self, consumer: BaseSchedulerNode 1282 ) -> Optional[BaseSchedulerNode]: 1283 producers = set() 1284 for rd in consumer.read_writes.reads: 1285 if rd.name not in self.scheduler.name_to_buf: 1286 continue 1287 1288 node_name = self.scheduler.name_to_buf[rd.name].defining_op.get_name() 1289 if node_name in self.name_to_node: 1290 producers.add(self.name_to_node[node_name]) 1291 1292 # Don't permit fusion if there are multiple subnodes 1293 # that this consumer reads from 1294 if len(producers) == 1: 1295 return next(iter(producers)) 1296 else: 1297 return None 1298 1299 @classmethod 1300 def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: 1301 why = WhyNoFuse(producer, consumer) 1302 if producer.is_foreach() and consumer.is_foreach(): 1303 producer = typing.cast(ForeachKernelSchedulerNode, producer) 1304 consumer = typing.cast(ForeachKernelSchedulerNode, consumer) 1305 foreach_match = len(producer.snodes) == len(consumer.snodes) 1306 if not foreach_match: 1307 why("foreach do not have same length") 1308 return foreach_match and all( 1309 producer.scheduler.can_fuse(l, r) 1310 for l, r in zip(producer.snodes, consumer.snodes) 1311 ) 1312 elif consumer.is_foreach(): 1313 if producer.is_reduction(): 1314 why( 1315 "candidate producer is a reduction, foreach ops cannot be fused with reductions currently" 1316 ) 1317 return False 1318 1319 consumer = typing.cast(ForeachKernelSchedulerNode, consumer) 1320 consumer_subnode = consumer.get_consumer_subnode_for(producer) 1321 if consumer_subnode is not None: 1322 return consumer.scheduler.can_fuse(producer, consumer_subnode) 1323 1324 why("candidate producer is not dep of any foreach consumer") 1325 return False 1326 1327 elif producer.is_foreach(): 1328 if consumer.is_reduction(): 1329 why( 1330 "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently" 1331 ) 1332 return False 1333 1334 producer = typing.cast(ForeachKernelSchedulerNode, producer) 1335 producer_subnode = producer.get_producer_subnode_for(consumer) 1336 if producer_subnode is not None: 1337 return producer.scheduler.can_fuse(producer_subnode, consumer) 1338 1339 why("candidate consumer has no dep in any foreach producer") 1340 return False 1341 1342 raise AssertionError( 1343 "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" 1344 ) 1345 1346 @classmethod 1347 def fuse( 1348 cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode 1349 ) -> ForeachKernelSchedulerNode: 1350 assert producer.is_foreach() or consumer.is_foreach() 1351 if producer.is_foreach(): 1352 producer = typing.cast(ForeachKernelSchedulerNode, producer) 1353 use_custom_partition_algo = producer.use_custom_partition_algo 1354 enable_autotune = producer.enable_autotune 1355 else: 1356 consumer = typing.cast(ForeachKernelSchedulerNode, consumer) 1357 use_custom_partition_algo = consumer.use_custom_partition_algo 1358 enable_autotune = consumer.enable_autotune 1359 prev_node_1 = None 1360 prev_node_2 = None 1361 fused_nodes: List[BaseSchedulerNode] 1362 if producer.is_foreach() and consumer.is_foreach(): 1363 producer = typing.cast(ForeachKernelSchedulerNode, producer) 1364 consumer = 
typing.cast(ForeachKernelSchedulerNode, consumer) 1365 fused_nodes = [ 1366 FusedSchedulerNode.fuse(l, r) 1367 for l, r in zip(producer.snodes, consumer.snodes) 1368 ] 1369 elif producer.is_foreach(): 1370 producer = typing.cast(ForeachKernelSchedulerNode, producer) 1371 producer_subnode = producer.get_producer_subnode_for(consumer) 1372 fused_nodes = [] 1373 prev_node_1 = producer 1374 prev_node_2 = None 1375 for node in producer.snodes: 1376 if node is producer_subnode: 1377 new_node = FusedSchedulerNode.fuse(node, consumer) 1378 prev_node_2 = new_node 1379 fused_nodes.append(new_node) 1380 else: 1381 fused_nodes.append(node) 1382 1383 elif consumer.is_foreach(): 1384 consumer = typing.cast(ForeachKernelSchedulerNode, consumer) 1385 consumer_subnode = consumer.get_consumer_subnode_for(producer) 1386 fused_nodes = [] 1387 prev_node_1 = consumer 1388 prev_node_2 = None 1389 1390 for node in consumer.snodes: 1391 if node is consumer_subnode: 1392 new_node = FusedSchedulerNode.fuse(producer, node) 1393 prev_node_2 = new_node 1394 fused_nodes.append(new_node) 1395 else: 1396 fused_nodes.append(node) 1397 else: 1398 raise AssertionError( 1399 "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node" 1400 ) 1401 1402 return cls( 1403 producer.scheduler, 1404 fused_nodes, 1405 use_custom_partition_algo=use_custom_partition_algo, 1406 prev_node_1=prev_node_1, 1407 prev_node_2=prev_node_2, 1408 enable_autotune=enable_autotune, 1409 ) 1410 1411 def __init__( 1412 self, 1413 scheduler: Scheduler, 1414 snodes: List[BaseSchedulerNode], 1415 use_custom_partition_algo: bool, 1416 prev_node_1: Optional[BaseSchedulerNode] = None, 1417 prev_node_2: Optional[BaseSchedulerNode] = None, 1418 enable_autotune: bool = False, 1419 ) -> None: 1420 self.read_to_node = {} 1421 self.name_to_node = {} 1422 1423 if prev_node_1 is None or prev_node_2 is None: 1424 super().__init__(scheduler, snodes) 1425 1426 for node in snodes: 1427 for read in node.read_writes.reads: 1428 self.read_to_node[read.name] = node 1429 1430 for name in node.get_operation_names(): 1431 self.name_to_node[name] = node 1432 else: 1433 self.scheduler = scheduler 1434 self.snodes = snodes 1435 self.node = None 1436 self.users: List[NodeUser] = [] 1437 1438 self.set_read_writes( 1439 dependencies.ReadWrites.merge_list( 1440 [prev_node_1.read_writes, prev_node_2.read_writes] 1441 ) 1442 ) 1443 1444 self.unmet_dependencies = ( 1445 OrderedSet( 1446 dep 1447 for dep in OrderedSet.union( 1448 prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies 1449 ) 1450 if dep.name not in self.get_buffer_names() 1451 ) 1452 - self.read_writes.writes 1453 ) 1454 1455 self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) 1456 self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) 1457 1458 if prev_node_1.is_foreach(): 1459 assert isinstance(prev_node_1, ForeachKernelSchedulerNode) 1460 foreach_node, other_node = prev_node_1, prev_node_2 1461 else: 1462 assert isinstance(prev_node_2, ForeachKernelSchedulerNode) 1463 foreach_node, other_node = prev_node_2, prev_node_1 1464 1465 self.ancestors = foreach_node.ancestors 1466 self.ancestors.update(other_node.ancestors) 1467 1468 self.name_to_node = foreach_node.name_to_node 1469 for name in other_node.get_operation_names(): 1470 self.name_to_node[name] = other_node 1471 1472 self.use_custom_partition_algo = use_custom_partition_algo 1473 self.group = (snodes[0].get_device(), ((sympy.Expr("combo_kernel"),),)) 1474 self.origins: OrderedSet[torch.fx.Node] = 
OrderedSet() 1475 self.enable_autotune = enable_autotune 1476 1477 @classmethod 1478 def combinable_nodes( 1479 cls, nodes: List[BaseSchedulerNode] 1480 ) -> List[BaseSchedulerNode]: 1481 extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)] 1482 if extern: 1483 log.debug( 1484 "ComboKernels: %d external nodes are filtered %s", 1485 len(extern), 1486 [node.node.get_origins() for node in extern if node.node is not None], 1487 ) 1488 filtered_nodes = [ 1489 x 1490 for x in nodes 1491 if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode)) 1492 ] 1493 foreach_nodes = [ 1494 x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode) 1495 ] 1496 if foreach_nodes: 1497 log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes)) 1498 filtered_nodes = [ 1499 x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode) 1500 ] 1501 template_nodes = [x for x in filtered_nodes if x.is_template()] 1502 if template_nodes: 1503 log.debug( 1504 "ComboKernels: %d template nodes are filtered", {len(template_nodes)} 1505 ) 1506 filtered_nodes = [x for x in filtered_nodes if x not in template_nodes] 1507 return filtered_nodes 1508 1509 @staticmethod 1510 def _default_group_nodes_for_combo_kernels( 1511 scheduler: Scheduler, 1512 ) -> List[List[BaseSchedulerNode]]: 1513 """ 1514 Returns a list of lists of nodes that are to be grouped together. 1515 """ 1516 sorted_nodes = scheduler._topological_sort_nodes() 1517 grouped_nodes = [] 1518 max_num_nodes = 8 1519 for nodes in sorted_nodes: 1520 grouped_nodes.extend( 1521 [ 1522 nodes[i : i + max_num_nodes] 1523 for i in range(0, len(nodes), max_num_nodes) 1524 ] 1525 ) 1526 1527 return grouped_nodes 1528 1529 group_algorithm_for_combo_kernels: Callable[ 1530 [Scheduler], List[List[BaseSchedulerNode]] 1531 ] = _default_group_nodes_for_combo_kernels 1532 1533 @staticmethod 1534 def set_group_algorithm_for_combo_kernels( 1535 custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]] 1536 ) -> None: 1537 ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( 1538 custom_group_algorithm 1539 ) 1540 1541 @staticmethod 1542 def group_nodes_for_combo_kernels( 1543 scheduler: Scheduler, 1544 ) -> List[List[BaseSchedulerNode]]: 1545 return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler) 1546 1547 def mark_run(self) -> None: 1548 raise NotImplementedError 1549 1550 def codegen(self) -> None: 1551 assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" 1552 self.node.get_store_function()(self.node.make_loader()()) 1553 1554 def is_foreach(self) -> bool: 1555 return True 1556 1557 def get_subkernel_nodes(self) -> List[BaseSchedulerNode]: 1558 """Returns a list of nodes which comprise the combo kernel. 
1559 These nodes may be vertically fused.""" 1560 return list(self.snodes) 1561 1562 def get_nodes(self) -> Sequence[BaseSchedulerNode]: 1563 """Returns all nodes contained in this kernel, unpacking fused nodes 1564 into their constituent scheduler nodes.""" 1565 return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) 1566 1567 def get_first_name(self) -> str: 1568 return self.snodes[0].get_first_name() 1569 1570 def prune_redundant_deps( 1571 self, name_to_fused_node: Dict[str, BaseSchedulerNode] 1572 ) -> None: 1573 _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) 1574 1575 for node in self.snodes: 1576 node.prune_redundant_deps(name_to_fused_node) 1577 1578 1579class GroupedSchedulerNode(BaseSchedulerNode): 1580 """ 1581 This is a "fake" scheduler node that represents a group of scheduler nodes 1582 that are meant to be *grouped* together (it does not allow another node to be scheduled 1583 in between its constituent nodes, nor does it allow another node to fuse into any of its constituent nodes). 1584 The way it does this is by maintaining its unmet dependencies as the union of its constituent nodes. 1585 Fusion will still happen among the nodes within each GroupedSchedulerNode. 1586 At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node. 1587 """ 1588 1589 snodes: List[BaseSchedulerNode] 1590 1591 @classmethod 1592 def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode: 1593 scheduler = snodes[0].scheduler 1594 assert all(node.scheduler is scheduler for node in snodes) 1595 grouped_snode = cls(scheduler, snodes) # type: ignore[arg-type] 1596 for snode in snodes: 1597 scheduler.name_to_fused_node[snode.get_name()] = grouped_snode 1598 scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode 1599 return grouped_snode 1600 1601 def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: 1602 super().__init__(scheduler) 1603 init_group_node(self, scheduler, snodes) 1604 1605 def unpack(self) -> List[BaseSchedulerNode]: 1606 """ 1607 Do fusion among nodes within this GroupedSchedulerNode, 1608 and then unpack this GroupedSchedulerNode into regular nodes. 
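        Returns the resulting list of (possibly further fused) nodes.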
1609 """ 1610 for snode in self.snodes: 1611 self.scheduler.name_to_fused_node[snode.get_name()] = snode 1612 del self.scheduler.name_to_fused_node[self.get_name()] 1613 return self.scheduler.fuse_nodes(self.snodes) 1614 1615 def add_fake_dep(self, fake_dep: Dep) -> None: 1616 self.set_read_writes(self.read_writes.with_read(fake_dep)) 1617 self.unmet_dependencies.add(fake_dep) 1618 1619 @cache_on_self 1620 def get_name(self) -> str: 1621 return "_".join([x.get_name() for x in self.snodes]) 1622 1623 def get_first_name(self) -> str: 1624 return self.snodes[0].get_name() 1625 1626 @cache_on_self 1627 def get_buffer_names(self) -> OrderedSet[str]: 1628 return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) 1629 1630 def get_outputs(self) -> List[SchedulerBuffer]: 1631 result: List[SchedulerBuffer] = [] 1632 for node in self.snodes: 1633 result.extend(node.get_outputs()) 1634 return result 1635 1636 def get_nodes(self) -> Sequence[BaseSchedulerNode]: 1637 return self.snodes 1638 1639 @classmethod 1640 def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: 1641 # GroupedSchedulerNode cannot be fused with another node 1642 return False 1643 1644 1645def pick_loop_order( 1646 stride_lengths: List[List[int]], 1647 sizes: List[sympy.Expr], 1648 priority_idx: Tuple[int, ...] = (), 1649) -> List[int]: 1650 """ 1651 A heuristic to decide loop iteration orders. This has not been well 1652 tuned and may be something we should autotune. 1653 """ 1654 1655 @functools.cmp_to_key 1656 def index_cmp(a: int, b: int) -> int: 1657 if sizes[a] == 1 or sizes[b] == 1: 1658 # 1-sizes don't matter, just move them to the end 1659 return cmp(sizes[a] == 1, sizes[b] == 1) 1660 1661 # Take abs, otherwise flipped dimensions are treated as smaller 1662 # strides than contiguous dims 1663 stride_len_a = [abs(sl[a]) for sl in stride_lengths] 1664 stride_len_b = [abs(sl[b]) for sl in stride_lengths] 1665 1666 # equivalent to 1667 # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() 1668 a_first = sum( 1669 sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) 1670 ) 1671 b_first = sum( 1672 sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) 1673 ) 1674 if a_first > b_first: 1675 return -1 1676 if b_first > a_first: 1677 return 1 1678 1679 # otherwise contiguous 1680 return cmp(b, a) 1681 1682 order = list(reversed(range(len(stride_lengths[0])))) 1683 if len(priority_idx) > 0: 1684 # if we have priority node, only use that node's order 1685 stride_lengths = [stride_lengths[pi] for pi in priority_idx] 1686 if config.pick_loop_orders: 1687 order.sort(key=index_cmp) 1688 return order 1689 1690 1691@dataclasses.dataclass 1692class NodeUser: 1693 node: Union[BaseSchedulerNode, OutputNode] 1694 can_inplace: bool = False 1695 1696 # A weak user must be scheduled after a given node, but doesn't actually 1697 # use the result 1698 is_weak: bool = False 1699 1700 def __hash__(self) -> int: 1701 return hash((self.node.get_name(), self.can_inplace, self.is_weak)) 1702 1703 def __eq__(self, other: object) -> bool: 1704 return ( 1705 isinstance(other, NodeUser) 1706 and self.get_name() == other.get_name() 1707 and self.can_inplace == other.can_inplace 1708 and self.is_weak == other.is_weak 1709 ) 1710 1711 def get_name(self) -> str: 1712 return self.node.get_name() 1713 1714 def merge(self, other: NodeUser) -> NodeUser: 1715 assert self.node is other.node 1716 return NodeUser( 1717 self.node, 1718 
self.can_inplace and other.can_inplace, 1719 self.is_weak and other.is_weak, 1720 ) 1721 1722 1723_post_grad_graph_counter = itertools.count() 1724 1725 1726class Scheduler: 1727 __dep_size_hint_cache: Dict[Dep, int] 1728 1729 def __init__(self, nodes: List[ir.Operation]) -> None: 1730 with dynamo_timed("Scheduler.__init__"): 1731 self._init(nodes) 1732 1733 def _init(self, nodes: List[ir.Operation]) -> None: 1734 super().__init__() 1735 self.__dep_size_hint_cache = {} 1736 V.graph.scheduler = self 1737 self.backends: Dict[torch.device, BaseScheduling] = {} 1738 self.post_grad_graph_id = next(_post_grad_graph_counter) 1739 1740 self.completed_operations: OrderedSet[str] = OrderedSet() 1741 self.available_buffer_names = OrderedSet( 1742 [ 1743 *V.graph.graph_inputs.keys(), 1744 *V.graph.constants.keys(), 1745 *V.graph.torchbind_constants.keys(), 1746 ] 1747 ) 1748 1749 self.nodes = [self.create_scheduler_node(n) for n in nodes] 1750 self.update_zero_dim_cpu_tensor() 1751 # some new constants could have been created above 1752 self.available_buffer_names.update(V.graph.constants.keys()) 1753 for node in self.nodes: 1754 node.prune_deps() 1755 1756 self.name_to_node: Dict[str, BaseSchedulerNode] = { 1757 n.get_name(): n for n in self.nodes 1758 } 1759 self.name_to_buf: Dict[str, SchedulerBuffer] = { 1760 buf.get_name(): buf for node in self.nodes for buf in node.get_outputs() 1761 } 1762 self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy() 1763 1764 # mutation_real_name: Maps back to the original name for codegen 1765 # Example: 1766 # If you mutate buf0 inside of buf1's kernel, then: 1767 # mutation_real_name = {"buf0" : "buf1"} 1768 # all subsequent uses of buf0 become buf1's usage in dependency graph 1769 self.mutation_real_name: Dict[str, str] = {} 1770 1771 # We handle mutation by renaming modified versions of the same 1772 # buffer in the dependency graph to prevent cycles. 
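        # (Concretely, per the "update our renaming scheme" loop in
        # compute_dependencies(): mutating buf0 inside buf1's kernel records
        # mutation_renames["buf0"] = "buf1" for dependency tracking and
        # mutation_real_name["buf1"] = "buf0" for codegen.)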
1773 # mutation_renames: tracks the current name for a given buffer 1774 # (changed once per mutation) 1775 # Example: 1776 # If you mutate buf0 inside of buf1's kernel, then: 1777 # mutation_renames = {"buf1" : "buf0"} 1778 # in codegen we only use buf0, never buf1 1779 self.mutation_renames: Dict[str, str] = {} 1780 1781 self.compute_dependencies() 1782 self.nodes = self.topological_sort_schedule(self.nodes) 1783 self.dead_node_elimination() 1784 self.name_to_fused_node = {n.get_name(): n for n in self.nodes} 1785 self.compute_ancestors() 1786 if config.reorder_for_compute_comm_overlap: 1787 self.nodes = comms.decide_global_ordering_of_comms( 1788 self.nodes, 1789 self.name_to_buf, 1790 self.name_to_fused_node, 1791 ) 1792 1793 metrics.ir_nodes_pre_fusion += len(self.nodes) 1794 V.debug.ir_pre_fusion(self.nodes) 1795 self.num_orig_nodes = len(self.nodes) 1796 self.create_foreach_nodes() 1797 self.nodes = self.topological_sort_schedule(self.nodes) 1798 self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet() 1799 if config._pre_fusion_custom_pass is not None: 1800 self.nodes = config._pre_fusion_custom_pass(self.nodes) 1801 self.nodes = self.fuse_nodes(self.nodes) 1802 self.merge_loops() 1803 self.finalize_multi_template_buffers() 1804 if config.reorder_for_compute_comm_overlap: 1805 self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) 1806 if config.combo_kernels: 1807 self.create_combo_kernel_nodes(num_ck_nodes=None) 1808 self.process_grouped_nodes() 1809 self.compute_last_usage() 1810 V.debug.ir_post_fusion(self.nodes) 1811 V.debug.graph_diagram(self.nodes) 1812 self.debug_draw_graph() 1813 1814 # used during codegen: 1815 self.current_device: Optional[torch.device] = None 1816 self.buffer_names_to_free: OrderedSet[str] = OrderedSet() 1817 1818 # fx graph node to the position it appears in the graph 1819 # for debug attribution 1820 self.origin_to_index: Dict[torch.fx.Node, int] = {} 1821 1822 get_metric_table("graph_stats").add_row( 1823 lambda: { 1824 "graph_id": self.post_grad_graph_id, 1825 "num_nodes_before_fusion": self.num_orig_nodes, 1826 "num_nodes_after_fusion": len(self.nodes), 1827 } 1828 ) 1829 1830 def get_current_device_or_throw(self) -> torch.device: 1831 if device := self.current_device: 1832 return device 1833 else: 1834 raise RuntimeError("No current device") 1835 1836 def debug_draw_graph(self) -> None: 1837 """Generate an image of the graph for debugging""" 1838 if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": 1839 from .debug import draw_buffers 1840 1841 draw_buffers(self.nodes, print_graph=True) 1842 1843 def debug_print_nodes(self, label: str) -> None: 1844 if log.isEnabledFor(logging.INFO): 1845 log.info("%s:", label) 1846 for node in self.nodes: 1847 node.log_details() 1848 1849 def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: 1850 assert ( 1851 node.get_origins() is not None 1852 ), "All nodes passed to scheduling must have an origin" 1853 if node.is_no_op(): 1854 return NopKernelSchedulerNode(self, node) 1855 elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): 1856 return SchedulerNode(self, node) 1857 elif isinstance(node, ir.ExternKernel): 1858 return ExternKernelSchedulerNode(self, node) 1859 else: 1860 raise NotImplementedError(node) 1861 1862 def create_foreach_nodes(self) -> None: 1863 removed_node_names: OrderedSet[str] = OrderedSet() 1864 fe_nodes = [] 1865 kept_node_names = self.name_to_fused_node.keys() 1866 1867 for names in V.graph.lists.values(): 1868 names = [ 1869 
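                # keep only list members that still exist as real (non-nop)
                # scheduler nodes; anything already eliminated is skipped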
name 1870 for name in names 1871 if name in kept_node_names 1872 and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) 1873 ] 1874 if not names: 1875 # All nodes eliminated 1876 continue 1877 1878 removed_node_names.update(names) 1879 snodes = [self.name_to_node[name] for name in names] 1880 1881 enable_autotune = config.combo_kernels_autotune > 1 1882 fe_node = ForeachKernelSchedulerNode( 1883 self, 1884 snodes, 1885 use_custom_partition_algo=False, 1886 enable_autotune=enable_autotune, 1887 ) 1888 1889 fe_nodes.append(fe_node) 1890 1891 for name in names: 1892 self.name_to_fused_node[name] = fe_node 1893 1894 self.nodes = [ 1895 node for node in self.nodes if node.get_name() not in removed_node_names 1896 ] + list(fe_nodes) 1897 1898 def compute_dependencies(self) -> None: 1899 """ 1900 Create dependency edges between nodes, handling aliasing and 1901 mutation properly. 1902 """ 1903 1904 T = TypeVar("T") 1905 1906 class DedupList(Generic[T]): 1907 """ 1908 This data structure behaves like a list except it makes sure the 1909 elements remain unique. 1910 Normally one could use a OrderedSet/dict for this purpose however 1911 the list in question gets elements appended as it is being 1912 iterated over which means that we need to keep the list 1913 semantics. 1914 """ 1915 1916 def __init__( 1917 self, 1918 items: Optional[List[T]] = None, 1919 membership: Optional[OrderedSet[T]] = None, 1920 ) -> None: 1921 self.items = items or [] 1922 self.membership = membership or OrderedSet() 1923 1924 def append(self, node_user: T) -> None: 1925 if node_user in self.membership: 1926 return 1927 self.items.append(node_user) 1928 self.membership.add(node_user) 1929 1930 def __add__(self, other: DedupList[T]) -> DedupList[T]: 1931 new_membership = OrderedSet.union(self.membership, other.membership) 1932 new_items = self.items + [ 1933 x for x in other.items if x not in self.membership 1934 ] 1935 return DedupList(new_items, new_membership) 1936 1937 name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( 1938 DedupList 1939 ) 1940 1941 # handle aliasing by using python aliasing in name_to_users 1942 # if foo aliases bar then we will make name_to_users["foo"] point 1943 # to the same python list as name_to_users["bar"] 1944 for node in self.nodes: 1945 for buf1 in node.get_outputs(): 1946 buf1_name = buf1.get_name() 1947 for buf2_name in buf1.get_aliases(): 1948 if buf1_name in name_to_users and buf2_name in name_to_users: 1949 # merge the two 1950 list1 = name_to_users[buf1_name] 1951 list2 = name_to_users[buf2_name] 1952 combined = list1 + list2 1953 for key in name_to_users.keys(): 1954 if ( 1955 name_to_users[key] is list1 1956 or name_to_users[key] is list2 1957 ): 1958 name_to_users[key] = combined 1959 elif buf1_name in name_to_users: 1960 name_to_users[buf2_name] = name_to_users[buf1_name] 1961 else: 1962 name_to_users[buf1_name] = name_to_users[buf2_name] 1963 1964 def rename(n: str) -> str: 1965 if n in self.mutation_renames: 1966 return rename(self.mutation_renames[n]) 1967 return n 1968 1969 def add_user( 1970 used_by_name: str, 1971 user_node: Union[BaseSchedulerNode, OutputNode], 1972 can_inplace: bool = False, 1973 is_weak: bool = False, 1974 ) -> None: 1975 name_to_users[rename(used_by_name)].append( 1976 NodeUser(user_node, can_inplace, is_weak) 1977 ) 1978 1979 unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {} 1980 1981 # NB: None means that the dependency is on an input. 
Don't actually 1982 # generate a dependency because if we do, Inductor will start trying 1983 # to free the unbacked int but that's pointless 1984 for name, val in V.graph.graph_inputs.items(): 1985 if isinstance(val, sympy.Expr): 1986 for fs in val.free_symbols: 1987 unbacked_symbol_to_origin_node[fs] = None 1988 1989 for node in self.nodes: 1990 log.debug("scheduling %s", node.node) 1991 1992 # unbacked symbols don't follow ordinary buffer dependencies, so 1993 # we track their def/uses separately 1994 assert node.node is not None 1995 unbacked_symbol_defs = sorted( 1996 node.node.get_unbacked_symbol_defs(), key=lambda x: x.name 1997 ) 1998 for s in unbacked_symbol_defs: 1999 assert isinstance(s, sympy.Symbol) 2000 # Pick the first definer as canonical. There may be multiple 2001 # because if a MultiOutputLayout buffer propagates an unbacked 2002 # symint to multiple outputs, they will all claim to def it. 2003 if s not in unbacked_symbol_to_origin_node: 2004 unbacked_symbol_to_origin_node[s] = node.get_name() 2005 2006 unbacked_symbol_uses = sorted( 2007 node.node.get_unbacked_symbol_uses(), key=lambda x: x.name 2008 ) 2009 # if a kernel takes unbacked symints, register dependencies 2010 for s in unbacked_symbol_uses: 2011 assert ( 2012 s in unbacked_symbol_to_origin_node 2013 ), f"{s} not in {unbacked_symbol_to_origin_node}" 2014 if (r := unbacked_symbol_to_origin_node[s]) is not None: 2015 for buf in self.name_to_node[r].get_outputs(): 2016 node.add_fake_dep(StarDep(buf.get_name())) 2017 2018 if ( 2019 len(node.read_writes.writes) == 1 2020 and (dep := next(iter(node.read_writes.writes))) 2021 and isinstance(dep, MemoryDep) 2022 ): 2023 node_mode = dep.mode 2024 else: 2025 node_mode = None 2026 2027 # Handle output mutations 2028 for buf in node.get_outputs(): 2029 # a node will mutate either 0 or 1 buffers 2030 assert len(buf.get_mutations()) <= 1 2031 for alt_name in buf.get_mutations(): 2032 alt_name = rename(alt_name) 2033 # this node must run after the prior writer 2034 add_user(alt_name, node) 2035 node.add_fake_dep(StarDep(alt_name, mode=node_mode)) 2036 for user in name_to_users[alt_name].items: 2037 if user.get_name() == node.get_name(): 2038 continue 2039 2040 assert isinstance(user.node, BaseSchedulerNode) 2041 for other_name in user.node.get_buffer_names(): 2042 # this node must run after all prior readers 2043 other_name = rename(other_name) 2044 node.add_fake_dep( 2045 WeakDep(other_name, mutating_buf=buf.get_name()) 2046 ) 2047 add_user(other_name, node, is_weak=True) 2048 2049 # add normal non-mutation dependencies 2050 for read in node.read_writes.reads: 2051 if not isinstance(read, WeakDep): 2052 add_user(read.name, node, node.can_inplace(read)) 2053 2054 node.update_mutated_names(self.mutation_renames) 2055 2056 # update our renaming scheme for the next iteration 2057 for buf in node.get_outputs(): 2058 for alt_name in buf.get_mutations(): 2059 self.mutation_renames[rename(alt_name)] = buf.get_name() 2060 self.mutation_renames[alt_name] = buf.get_name() 2061 self.mutation_real_name[ 2062 buf.get_name() 2063 ] = self.mutation_real_name.get(alt_name, alt_name) 2064 2065 # make sure outputs aren't dead-code-eliminated 2066 for buf_name in V.graph.get_output_names(): 2067 log.debug("scheduling output %s", buf_name) 2068 add_user(buf_name, OutputNode(StarDep(buf_name))) 2069 2070 # make sure unbacked symints aren't dead-code-eliminated 2071 for out in V.graph.graph_outputs: 2072 for s in out.get_unbacked_symbol_uses(): 2073 assert ( 2074 s in 
unbacked_symbol_to_origin_node 2075 ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" 2076 if r := unbacked_symbol_to_origin_node[s]: 2077 for buf_name in self.name_to_node[r].get_buffer_names(): 2078 log.debug( 2079 "scheduling output %s for unbacked symint %s", buf_name, s 2080 ) 2081 add_user(buf_name, OutputNode(StarDep(buf_name))) 2082 2083 # make sure input mutation isn't dead-code-eliminated 2084 for name in self.mutation_renames: 2085 if name in V.graph.graph_inputs: 2086 add_user(name, OutputNode(StarDep(name))) 2087 V.graph.mutated_inputs.add(name) 2088 elif name in V.graph.constants: 2089 # In AOTI, module parameters and buffers are not lifted as graph inputs 2090 add_user(name, OutputNode(StarDep(name))) 2091 2092 inp_names = { 2093 name: index for index, name in enumerate(V.graph.graph_inputs.keys()) 2094 } 2095 V.graph.mutated_input_idxs = [ 2096 inp_names[name] for name in V.graph.mutated_inputs 2097 ] 2098 2099 # copy users information onto the nodes 2100 for node in self.nodes: 2101 for buf in node.get_outputs(): 2102 buf.set_users(name_to_users[buf.get_name()].items) 2103 2104 def dead_node_elimination(self) -> None: 2105 """ 2106 Remove any nodes without users 2107 """ 2108 # self.nodes is in topological order, so by iterating in reverse order 2109 # we have visited (and potentially removed) all users before visiting a 2110 # given node. 2111 updated_nodes = [] 2112 for node in reversed(self.nodes): 2113 2114 def can_eliminate_user(user: NodeUser) -> bool: 2115 return user.is_weak or user.get_name() in V.graph.removed_operations 2116 2117 active_buffers = False 2118 for buf in node.get_outputs(): 2119 can_eliminate = all(can_eliminate_user(u) for u in buf.users) 2120 if can_eliminate: 2121 log.debug("removed dead buffer: %s", buf.get_name()) 2122 V.graph.removed_buffers.add(buf.get_name()) 2123 else: 2124 active_buffers = True 2125 2126 can_eliminate = not node.has_side_effects() and not active_buffers 2127 2128 if not can_eliminate: 2129 updated_nodes.append(node) 2130 else: 2131 # dead code 2132 log.debug("removed dead operation: %s", node.get_name()) 2133 V.graph.removed_operations.add(node.get_name()) 2134 2135 self.nodes = list(reversed(updated_nodes)) 2136 2137 # Prune any WeakDeps no longer needed 2138 for node in self.nodes: 2139 node.prune_weak_deps() 2140 2141 def topological_sort_schedule( 2142 self, nodes: List[BaseSchedulerNode] 2143 ) -> List[BaseSchedulerNode]: 2144 """ 2145 Ensure nodes is in topologically sorted order 2146 """ 2147 seen: OrderedSet[BaseSchedulerNode] = OrderedSet() 2148 name_to_node: Dict[str, BaseSchedulerNode] = dict() 2149 result: List[BaseSchedulerNode] = [] 2150 2151 def visit(n: BaseSchedulerNode) -> None: 2152 if n not in seen: 2153 seen.add(n) 2154 for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): 2155 # We only care about doing toposort within `nodes` 2156 if dep.name not in name_to_node: 2157 continue 2158 visit(name_to_node[dep.name]) 2159 result.append(n) 2160 2161 for node in nodes: 2162 for name in node.get_buffer_names(): 2163 name_to_node[name] = node 2164 for node in nodes: 2165 visit(node) 2166 return result 2167 2168 def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]: 2169 unmet_deps = set() 2170 if isinstance( 2171 snode, 2172 ( 2173 SchedulerNode, 2174 ExternKernelSchedulerNode, 2175 NopKernelSchedulerNode, 2176 FusedSchedulerNode, 2177 ), 2178 ): 2179 for dep in snode.unmet_dependencies: 2180 unmet_deps.add(dep.name) 2181 else: 2182 raise RuntimeError( 2183 
f"get_unmet_dep_nodes is not implemented for {type(snode)}." 2184 ) 2185 unmet_dep_ops = (self.name_to_buf[dep].defining_op for dep in unmet_deps) 2186 return list({self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops}) 2187 2188 def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]: 2189 """ 2190 Sort nodes by their topological order, return a list of node lists. 2191 """ 2192 order = [] 2193 nodes = dict.fromkeys(self.nodes, 0) 2194 children: Dict[Any, Any] = {} 2195 for node in self.nodes: 2196 deps = self._get_unmet_dep_nodes(node) 2197 nodes[node] = len(deps) 2198 for dep in deps: 2199 c = children.get(dep, []) 2200 c.append(node) 2201 children[dep] = c 2202 2203 zero_deg_nodes = [n for n, v in nodes.items() if v == 0] 2204 while zero_deg_nodes: 2205 order.append(zero_deg_nodes) 2206 for n in zero_deg_nodes: 2207 for user in children.get(n, []): 2208 nodes[user] -= 1 2209 nodes.pop(n) 2210 zero_deg_nodes = [n for n, v in nodes.items() if v == 0] 2211 assert not nodes, "Topological sort failed!" 2212 return order 2213 2214 def compute_ancestors(self) -> None: 2215 """ 2216 Populate each node.ancestors 2217 """ 2218 # note self.nodes is topologically sorted 2219 name_to_ancestors: Dict[str, OrderedSet[str]] = {} 2220 for node in self.nodes: 2221 ancestors: OrderedSet[str] = OrderedSet() 2222 for dep in node.unmet_dependencies: 2223 dep_node_name = self.name_to_buf[dep.name].defining_op.get_name() 2224 ancestors.add(dep_node_name) 2225 ancestors |= name_to_ancestors[dep_node_name] 2226 name_to_ancestors[node.get_name()] = ancestors 2227 node.ancestors = ancestors 2228 2229 for order, node in enumerate(self.nodes): 2230 node.min_order = order 2231 node.max_order = order 2232 2233 def merge_loops(self) -> None: 2234 for node in self.nodes: 2235 if not config.loop_ordering_after_fusion: 2236 continue 2237 2238 # Even for CPU, if we are using the halide backend, we still need 2239 # the merge loops steps below 2240 if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( 2241 node.get_device().type != "cuda" and config.cpu_backend != "halide" 2242 ): 2243 continue 2244 for snode in node.get_nodes(): 2245 # merge loops for the scheduler node 2246 if not isinstance(snode, SchedulerNode) or snode.is_template(): 2247 continue 2248 2249 snode._body = snode._body.merge_loops() 2250 snode._sizes = snode._body.sizes 2251 2252 # merge_loops is called after loop reordering. 2253 # We still need retain fake dependencies since codegen the 2254 # estimated amount of memory access rely on them. 2255 snode.refresh_dependencies(normalize=True) 2256 2257 # Note that for CPU backend, merging loops will change 2258 # snode.group. It's fine for Triton backend. 2259 # But if we simplify update snode.group like this: 2260 # group_fn = self.get_backend(snode.node.get_device()).group_fn 2261 # snode.group = (snode.node.get_device(), group_fn(snode._sizes)) 2262 # There is still an issue due to different snode in a 2263 # FusedSchedulerNode having different merged loops. 2264 # Skip CPU backend for now. 2265 2266 def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: 2267 """ 2268 Combine eligible nodes into FusedSchedulerNodes. 
2269 """ 2270 for i in range(10): 2271 old_len = len(nodes) 2272 fusion_log.debug( 2273 "===== attempting fusion (%d/10): %d nodes =====", 2274 i + 1, 2275 old_len, 2276 ) 2277 nodes = self.fuse_nodes_once(nodes) 2278 new_len = len(nodes) 2279 fusion_log.debug( 2280 "completed fusion round (%d/10): fused %d nodes into %d nodes\n", 2281 i + 1, 2282 old_len, 2283 new_len, 2284 ) 2285 if new_len == old_len or new_len == 1: 2286 fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1) 2287 break 2288 return nodes 2289 2290 def process_grouped_nodes(self) -> None: 2291 """ 2292 Unpack GroupedSchedulerNode into regular nodes. 2293 """ 2294 new_nodes: List[BaseSchedulerNode] = [] 2295 for node in self.nodes: 2296 new_nodes.extend( 2297 node.unpack() if isinstance(node, GroupedSchedulerNode) else [node] 2298 ) 2299 self.nodes = new_nodes 2300 2301 def benchmark_fused_nodes( 2302 self, nodes: Sequence[BaseSchedulerNode] 2303 ) -> Tuple[float, str]: 2304 """ 2305 Benchmark fused list of nodes and return the execution time 2306 in milliseconds on randomly generated inputs. 2307 """ 2308 assert len(nodes) > 0 2309 device = nodes[0].get_device() 2310 self.current_device = device 2311 backend = self.get_backend(device) 2312 return backend.benchmark_fused_nodes(nodes) 2313 2314 def finalize_multi_template_buffers(self) -> None: 2315 def replace_operation_buffer( 2316 orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer 2317 ) -> None: 2318 replaced_buf_name = new_node.get_name() 2319 orig_buf_name = orig_node.get_name() 2320 assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str) 2321 2322 replaced_op_name = new_node.get_operation_name() 2323 orig_op_name = orig_node.get_operation_name() 2324 assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str) 2325 2326 del V.graph.name_to_buffer[replaced_buf_name] 2327 new_node.name = orig_buf_name 2328 2329 del V.graph.name_to_op[replaced_op_name] 2330 new_node.operation_name = orig_op_name 2331 2332 orig = V.graph.buffers.index(orig_node) 2333 V.graph.buffers.remove(new_node) 2334 V.graph.buffers[orig] = new_node 2335 V.graph.name_to_buffer[orig_buf_name] = new_node 2336 2337 orig = V.graph.operations.index(orig_node) 2338 V.graph.operations.remove(new_node) 2339 V.graph.operations[orig] = new_node 2340 V.graph.name_to_op[orig_op_name] = new_node 2341 2342 for i, node in enumerate(self.nodes): 2343 if isinstance(node, SchedulerNode) and isinstance( 2344 node.node, ir.MultiTemplateBuffer 2345 ): 2346 multi_node = node.node 2347 min_node_unfused, _ = multi_node.get_min_choice() 2348 2349 if isinstance( 2350 min_node_unfused, 2351 torch._inductor.ir.TritonTemplateCallerBase, 2352 ): 2353 node.node.finalize_as_triton_caller(min_node_unfused) 2354 continue 2355 2356 out_tensorbox = min_node_unfused.output_node() 2357 out_storage = out_tensorbox.data 2358 assert isinstance(out_storage, ir.StorageBox) 2359 out_buffer = out_storage.data 2360 assert isinstance(out_buffer, ir.OperationBuffer) 2361 2362 out_buffer.layout = multi_node.layout 2363 replace_operation_buffer(multi_node, out_buffer) 2364 new_scheduler_node = self.create_scheduler_node(out_buffer) 2365 2366 self.nodes[i] = new_scheduler_node 2367 self.name_to_node[node.get_name()] = new_scheduler_node 2368 self.name_to_fused_node[node.get_name()] = new_scheduler_node 2369 2370 for new_out, old_out in zip( 2371 new_scheduler_node.get_outputs(), node.get_outputs() 2372 ): 2373 self.name_to_buf[old_out.get_name()] = new_out 2374 new_out.users = 
old_out.users 2375 2376 new_scheduler_node.min_order = node.min_order 2377 new_scheduler_node.max_order = node.max_order 2378 new_scheduler_node.last_usage = node.last_usage 2379 2380 def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool: 2381 return any( 2382 hasattr(n.node, "data") 2383 and n.node is not None 2384 and hasattr(n.node.data, "scatter_mode") 2385 and n.node.data.scatter_mode == "atomic_add" 2386 for n in node_list 2387 ) 2388 2389 def speedup_by_fusion( 2390 self, node1: BaseSchedulerNode, node2: BaseSchedulerNode 2391 ) -> bool: 2392 """ 2393 If config.benchmark_fusion is False, always return True. 2394 Otherwise, return True if fusion can brings speedup. 2395 """ 2396 2397 is_multi_template = node1.is_template() and isinstance( 2398 node1.get_template_node(), ir.MultiTemplateBuffer 2399 ) 2400 if not config.benchmark_fusion and not is_multi_template: 2401 return True 2402 2403 if ( 2404 node1.is_template() 2405 and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) 2406 or node1.is_foreach() 2407 or node2.is_foreach() 2408 ): 2409 # TODO support benchmarking epilogue fusion 2410 return True 2411 2412 node_list_1 = node1.get_nodes() 2413 device = node_list_1[0].get_device() 2414 2415 # don't support benchmark fusion for CPU right now. 2416 if device.type == "cpu": 2417 return True 2418 2419 node_list_2 = node2.get_nodes() 2420 node_list_fused = list(itertools.chain(node_list_1, node_list_2)) 2421 2422 # We can not accurately benchmark kernel using atomic_add 2423 # due to how we generate random integer inputs. 2424 # Skip benchmarking them by allowing fusion. 2425 if self._any_atomic_add(node_list_fused): 2426 return True 2427 2428 from triton.compiler.errors import CompilationError 2429 2430 why = WhyNoFuse(node1, node2) 2431 2432 def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: 2433 if fusion_log.isEnabledFor(logging.DEBUG): 2434 if ms_fused < ms1 + ms2: 2435 fusion_log.debug( 2436 "can fuse (benchmark): fusing %s with %s cause %sx speedup", 2437 node1.get_buffer_names(), 2438 node2.get_buffer_names(), 2439 green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), 2440 ) 2441 else: 2442 fusion_log.debug( 2443 "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", 2444 node1.get_buffer_names(), 2445 node2.get_buffer_names(), 2446 red_text(f"{ms_fused / (ms1 + ms2):.3f}"), 2447 ) 2448 2449 if isinstance(node1, SchedulerNode) and isinstance( 2450 node1.node, ir.MultiTemplateBuffer 2451 ): 2452 multi_node = node1.node 2453 choice_timings = multi_node.choice_timings 2454 2455 _, ms1 = multi_node.get_min_choice() 2456 ms2, path2 = self.benchmark_fused_nodes(node_list_2) 2457 2458 min_ms_fused = float("inf") 2459 ms_fused_choice = None 2460 2461 triton_choices = 0 2462 2463 for choice, unfused_time in sorted( 2464 choice_timings.items(), key=lambda x: x[1] 2465 ): 2466 if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): 2467 continue 2468 2469 if unfused_time >= ms1 + ms2: 2470 break 2471 2472 triton_choices += 1 2473 if triton_choices > config.max_epilogue_benchmarked_choices: 2474 break 2475 2476 # TODO - parallel compile triton templates 2477 # TODO - should prune/skip choices that are not within certain % of best choice 2478 with node1.node.swap_as_triton_caller(choice): 2479 ms_fused, _ = self.benchmark_fused_nodes(node_list_fused) 2480 2481 if ms_fused < min_ms_fused: 2482 min_ms_fused = ms_fused 2483 ms_fused_choice = choice 2484 2485 log_fusion(min_ms_fused, ms1, ms2) 2486 2487 # after we do a fusion, 
we finalize a triton template. 2488 # TODO - could preserve multi template and choices for subsequent fusions 2489 if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: 2490 node1.node.finalize_as_triton_caller(ms_fused_choice) 2491 return True 2492 else: 2493 return False 2494 else: 2495 try: 2496 ms1, path1 = self.benchmark_fused_nodes(node_list_1) 2497 if math.isinf(ms1): 2498 why("register spilling of the first kernel") 2499 return False 2500 ms2, path2 = self.benchmark_fused_nodes(node_list_2) 2501 if math.isinf(ms2): 2502 why("register spilling of the second kernel") 2503 return False 2504 ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused) 2505 if math.isinf(ms_fused): 2506 why("register spilling of the fused kernel") 2507 return False 2508 except CompilationError as e: 2509 # workaround triton issue: https://github.com/openai/triton/issues/2151 2510 if "Loop-carried variable" in str(e): 2511 return True # allow fusion 2512 else: 2513 raise 2514 2515 log_fusion(ms_fused, ms1, ms2) 2516 if ( 2517 is_metric_table_enabled("slow_fusion") 2518 and ms_fused >= ms1 + ms2 2519 and (path1, path2) not in self.logged_slow_fusion 2520 ): 2521 self.logged_slow_fusion.add((path1, path2)) 2522 get_metric_table("slow_fusion").add_row( 2523 lambda: { 2524 "kernel1_path": path1, 2525 "kernel1_latency": ms1, 2526 "kernel2_path": path2, 2527 "kernel2_latency": ms2, 2528 "fused_kernel_path": path_fused, 2529 "fused_kernel_latency": ms_fused, 2530 "slow_down_ratio": ms_fused / (ms1 + ms2), 2531 } 2532 ) 2533 return ms_fused < ms1 + ms2 2534 2535 def fuse_nodes_once( 2536 self, nodes: List[BaseSchedulerNode] 2537 ) -> List[BaseSchedulerNode]: 2538 """ 2539 Combine eligible nodes into FusedSchedulerNodes. 2540 2541 This relies on two key functions to control the logic: 2542 - self.can_fuse(): checks if a fusion is legal 2543 - self.score_fusion(): assigns priority to a given fusion 2544 """ 2545 fused_nodes = OrderedSet(nodes) 2546 if fusion_log.isEnabledFor(logging.DEBUG): 2547 fusion_log.debug("fuse_nodes_once, candidates:") 2548 for node in fused_nodes: 2549 fusion_log.debug(" " + node.debug_str_short()) # noqa: G003 2550 for node1, node2 in self.get_possible_fusions(nodes): 2551 node1 = self.name_to_fused_node[node1.get_first_name()] 2552 node2 = self.name_to_fused_node[node2.get_first_name()] 2553 if self.can_fuse(node1, node2) and not self.will_fusion_create_cycle( 2554 node1, node2 2555 ): 2556 if not self.speedup_by_fusion(node1, node2): 2557 continue 2558 fusion_log.debug( 2559 "fusing %s with %s", node1.get_name(), node2.get_name() 2560 ) 2561 2562 # above can_fuse asserts that node2 has the same device 2563 device = node1.get_device() 2564 node3 = self.get_backend(device).fuse(node1, node2) 2565 fused_nodes.remove(node1) 2566 fused_nodes.remove(node2) 2567 fused_nodes.add(node3) 2568 self.name_to_fused_node.update( 2569 {n.get_name(): node3 for n in node3.get_nodes()} 2570 ) 2571 nodes = sorted(fused_nodes, key=lambda x: x.min_order) 2572 nodes = self.topological_sort_schedule(nodes) 2573 self.prune_redundant_deps(nodes) 2574 return nodes 2575 2576 def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None: 2577 """ 2578 Groups parallel nodes 2579 """ 2580 fused_nodes = set(self.nodes) 2581 count = 0 2582 num_nodes_orig = len(self.nodes) 2583 log.debug("ComboKernels: Generating with num_ck_nodes = %d...", num_ck_nodes) 2584 for num, node_list in enumerate( 2585 ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self) 2586 ): 2587 node_list = 
ForeachKernelSchedulerNode.combinable_nodes(node_list)
            if len(node_list) < 2:
                continue
            if num_ck_nodes is not None and count > num_ck_nodes:
                break
            if not self.speedup_by_combo_kernel(node_list):
                log.debug("ComboKernels: Not speeding up %d-th group", num)
                continue
            count += 1
            enable_autotune = config.combo_kernels_autotune > 0
            group_snode = ForeachKernelSchedulerNode(
                node_list[0].scheduler,
                node_list,
                use_custom_partition_algo=True,
                enable_autotune=enable_autotune,
            )
            log.info(
                "ComboKernels: Combining %d nodes for %d-th group",
                len(node_list),
                num,
            )
            for node in node_list:
                fused_nodes.remove(node)
            fused_nodes.add(group_snode)
            self.name_to_fused_node.update(
                {n.get_name(): group_snode for n in group_snode.get_nodes()}
            )
        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
        self.nodes = self.topological_sort_schedule(self.nodes)
        log.info(
            "Generated ComboKernel nodes: %d ComboKernels, %d -> %d nodes in total",
            count,
            num_nodes_orig,
            len(self.nodes),
        )
        self.prune_redundant_deps(self.nodes)

    def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None:
        for node in nodes:
            node.prune_redundant_deps(self.name_to_fused_node)

    def get_possible_fusions(
        self, nodes: List[BaseSchedulerNode]
    ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]:
        """
        Helper to find all legal fusion opportunities, sorted by self.score_fusion()
        """
        possible_fusions = []
        seen: OrderedSet[Tuple[BaseSchedulerNode, BaseSchedulerNode]] = OrderedSet()

        def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:
            for node1_index, node1 in enumerate(nodes):
                for node2 in nodes[node1_index + 1 :]:
                    key = (node1, node2)
                    if key in seen:
                        continue
                    seen.add(key)

                    if self.can_fuse(node1, node2):
                        possible_fusions.append(key)
                    elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                        node2, node1
                    ):
                        # foreach fusions and epilogue fusions are order dependent
                        possible_fusions.append((node2, node1))

        buffer_names_grouping = collections.defaultdict(list)
        for node in nodes:
            for buf in node.used_buffer_names():
                buffer_names_grouping[buf].append(node)
        for node_grouping in buffer_names_grouping.values():
            check_all_pairs(node_grouping)

        if config.aggressive_fusion:
            group_grouping = collections.defaultdict(list)
            for node in nodes:
                group = getattr(node, "group", None)
                if group:
                    group_grouping[group].append(node)
            for node_grouping in group_grouping.values():
                check_all_pairs(node_grouping)

        possible_fusions = self.get_possible_fusions_with_highest_priority(
            possible_fusions
        )
        possible_fusions.sort(key=self.score_fusion_key, reverse=True)
        fusion_log.debug("found %d possible fusions", len(possible_fusions))
        return possible_fusions

    def will_fusion_create_cycle(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Finds whether there's a path from node1 to node2 (or vice-versa)
        caused indirectly by other fusions.
        """
        # since we are just returning a boolean here, use a slightly faster, unordered set
        visited: Set[FusedSchedulerNode] = set()

        def found_path(node: BaseSchedulerNode) -> bool:
            # only fused nodes can introduce new ancestors.
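            # (Roughly: starting from an ancestor of the proposed fusion, walk
            # ancestor edges that exist only because of earlier fusions; if we
            # can reach back into node1/node2's own operations, fusing them
            # would close a cycle.)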
            if isinstance(node, FusedSchedulerNode) and node not in visited:
                visited.add(node)
                if node.get_operation_names().issubset(combined_ancestors):
                    # All fusion outputs are in the ancestors of node1 and node2,
                    # thus they cannot introduce a new path:
                    #
                    # 1. if the output is neither a descendant of node1 nor node2,
                    #    the output cannot introduce a path
                    # 2. due to [can_fuse]: if WLOG the output is a descendant of node1,
                    #    it cannot be on path(node1->node2), hence it cannot be an
                    #    ancestor of node2
                    # 3. due to [acyclic]: if WLOG the output is a descendant of node1,
                    #    it cannot be an ancestor of node1
                    return False
                else:
                    # continue DFS of new ancestors introduced by the fusion
                    return bool(combined_names & node.ancestors) or any(
                        found_path(self.name_to_fused_node[n])
                        for n in node.ancestors - combined_ancestors
                    )
            return False

        # as above - use a slightly faster, unordered set
        combined_names = (
            node1.get_operation_names()._dict.keys()
            | node2.get_operation_names()._dict.keys()
        )
        combined_ancestors = (
            node1.ancestors._dict.keys() | node2.ancestors._dict.keys()
        ) - combined_names
        cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
        if cycle:
            WhyNoFuse(node1, node2)("will create cycle")
        return cycle

    def can_fusion_increase_peak_memory(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        This function prevents fusion for nodes that can increase memory
        footprint. This problem is more common in horizontal fusion, where nodes
        that are far apart in the original order get fused, lengthening the live
        intervals of tensors. This is very evident in models with activation
        checkpointing, where the recomputed nodes from different checkpointed
        regions get fused and significantly increase the memory footprint.

        The current attempt is a quick, possibly hacky, heuristic to prevent the
        fusion of nodes that are far away in the original order.

        A better, but harder to implement, heuristic would be to use the live
        intervals of the buffers, find the region of peak pressure in the original
        program, and prevent fusions that cross that peak region. We might need
        special care or a good approximation in this implementation, as fusing
        nodes changes live intervals, and re-computing live intervals and peak
        memory after each fusion can introduce large compilation overhead.
        """
        proximity_score = max(
            abs(node1.min_order - node2.max_order),
            abs(node2.min_order - node1.max_order),
        )
        return proximity_score > 64

    def decide_fusion_fail_reason(
        self,
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
        common_buf_names: Tuple[str, ...],
    ) -> str:
        """
        Try to decide why fusion failed due to no shared memory, even though
        there are common buffers.
        """
        reasons = {}
        node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
        node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}

        for buf_name in common_buf_names:
            buf = V.graph.get_buffer(buf_name)
            lhs_dep = node1_name2dep[buf_name]
            rhs_dep = node2_name2dep[buf_name]

            if lhs_dep.get_numel() != rhs_dep.get_numel():
                reasons[
                    buf_name
                ] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
                continue

            # same numel but different MemoryDep.size. Should be broadcasting
            if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size):
                reasons[buf_name] = "broadcast"
                continue

            if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
                reasons[
                    buf_name
                ] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
                continue

            lhs_off = lhs_dep.get_offset()
            rhs_off = rhs_dep.get_offset()
            if lhs_off != rhs_off:
                # One example is in transformers, where we use a concatenated linear
                # layer to project Q/K/V and then split the result. The 3 splits will
                # point to the same buffer with different offsets.
                reasons[buf_name] = f"different offset: {lhs_off} v.s. {rhs_off}"
                continue

            if (
                lhs_dep.normalize_with_stride_order()
                == rhs_dep.normalize_with_stride_order()
            ):
                reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}"
                continue

            # Add more rules here
            reasons[
                buf_name
            ] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. Layout: {buf.layout}"

        return str(reasons)

    def has_shared_data_after_reordering_loop(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Right now just greedily reorder the loop of node1 to be compatible with node2,
        but ideally we should have some heuristics to reorder the loop for node2
        to be compatible with node1 if that's more efficient.
        """

        # TODO Don't do loop reordering for CPU for now.
        # Should debug further why it does not work for CPU codegen
        if not config.loop_ordering_after_fusion or any(
            n.get_device().type == "cpu" for n in [node1, node2]
        ):
            return False

        node1_buffer_names = node1.read_writes.buffer_names()
        node2_buffer_names = node2.read_writes.buffer_names()
        # Fast path: no common buffers.
        common_buffer_names = node1_buffer_names & node2_buffer_names
        if not common_buffer_names:
            return False

        node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
        node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}

        # Find the common buffers that have different loop orders
        candidates = []
        for buffer_name in common_buffer_names:
            lhs_dep = node1_name2dep[buffer_name]
            rhs_dep = node2_name2dep[buffer_name]
            if (
                lhs_dep.normalize_with_stride_order()
                == rhs_dep.normalize_with_stride_order()
            ):
                candidates.append(
                    (
                        V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0),
                        lhs_dep,
                        rhs_dep,
                    )
                )

        if len(candidates) == 0:
            return False

        # Pick the largest buffer to guide the loop reordering
        numel, lhs_dep, rhs_dep = sorted(candidates, reverse=True, key=lambda x: x[0])[
            0
        ]

        if lhs_dep.num_vars != rhs_dep.num_vars:
            # this can happen because we don't merge loops.
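            # (e.g. one dep may still index with two loop vars of sizes [s0, s1]
            # while the other already sees a single flattened s0 * s1 range)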
2861 # We can not do loop reordering in this case right now 2862 # Simply returning true if the two Deps are the same after 2863 # normalization (merging loops) 2864 return lhs_dep.normalize() == rhs_dep.normalize() 2865 2866 # Only reorder loops for pointwise for now 2867 if not node1.is_reduction(): 2868 node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) 2869 elif not node2.is_reduction(): 2870 node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) 2871 else: 2872 loop_ordering_log.debug( 2873 "Don't reorder loops since both nodes are reductions: %s v.s. %s", 2874 node1.get_name(), 2875 node2.get_name(), 2876 ) 2877 2878 return self.score_fusion_memory(node1, node2) > 0 2879 2880 def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: 2881 """ 2882 Determine if it is possible to combine node1 and node2 into a 2883 single fused node. 2884 """ 2885 2886 if node1 is node2: 2887 return False 2888 2889 why = WhyNoFuse(node1, node2) 2890 2891 if isinstance(node1, GroupedSchedulerNode) or isinstance( 2892 node2, GroupedSchedulerNode 2893 ): 2894 why("grouped node must not be fused with other nodes") 2895 return False 2896 if ( 2897 isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) 2898 and not node1.is_template() 2899 ): 2900 why("node1 is extern or nop") 2901 return False 2902 if ( 2903 isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) 2904 and not node2.is_template() 2905 ): 2906 why("node2 is extern or nop") 2907 return False 2908 2909 if node2.get_operation_names() & node1.ancestors: 2910 why("node1 must go before node2") 2911 return False 2912 2913 if node2.is_template(): 2914 why("templates can only fuse epilogues") 2915 return False 2916 if node1.is_template() and ( 2917 node2.has_aliasing_or_mutation() 2918 or node2.is_reduction() 2919 or not config.epilogue_fusion 2920 ): 2921 why("template epilogue not satisfied") 2922 return False 2923 2924 if ( 2925 node1.get_buffer_names() | node2.get_buffer_names() 2926 ) & V.graph.no_fuse_buffer_names: 2927 why("fusion for buffer explicit disabled") 2928 return False 2929 2930 device = node1.get_device() 2931 device2 = node2.get_device() 2932 if device != device2: 2933 why("device mismatch (%s vs %s)", device, device2) 2934 return False 2935 del device2 2936 2937 no_shared_data = self.score_fusion_memory(node1, node2) == 0 2938 if no_shared_data: 2939 no_shared_data = not self.has_shared_data_after_reordering_loop( 2940 node1, node2 2941 ) 2942 2943 loop_ordering_log.debug( 2944 "%s and %s has%s shared data", 2945 node1.get_name(), 2946 node2.get_name(), 2947 " no" if no_shared_data else "", 2948 ) 2949 if no_shared_data and ( 2950 not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() 2951 ): 2952 if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"): 2953 common_buf_names = ( 2954 node1.read_writes.buffer_names() & node2.read_writes.buffer_names() 2955 ) 2956 if len(common_buf_names) > 0: 2957 get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row( 2958 lambda: { 2959 "pre_grad_graph_id": V.graph.graph_id, 2960 "post_grad_graph_id": V.graph.post_grad_graph_id, 2961 "node1_name": node1.get_name(), 2962 "node2_name": node2.get_name(), 2963 "node1_debug_str": write_text(node1.debug_str()), 2964 "node2_debug_str": write_text(node2.debug_str()), 2965 "common_buffer_names": list(common_buf_names), 2966 "failure_reason": self.decide_fusion_fail_reason( 2967 node1, node2, common_buf_names 2968 ), 2969 } 2970 ) 2971 2972 why("no shared data due 
to indexing mismatch") 2973 return False 2974 why("no shared data") 2975 return False # heuristic not needed for correctness 2976 2977 if ( 2978 not node1.is_foreach() 2979 and not node2.is_foreach() 2980 and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size 2981 ): 2982 why("exceeds max fusion") 2983 return False # heuristic not needed for correctness 2984 2985 if node1.get_operation_names() & node2.ancestors: 2986 # node2 depends on node1 outputs 2987 if not self.can_fuse_vertical(node1, node2): 2988 return False 2989 return self.get_backend(device).can_fuse_vertical(node1, node2) 2990 else: # nodes don't depend on each other, but may have common reads 2991 if self.can_fusion_increase_peak_memory(node1, node2): 2992 why("will increase peak memory") 2993 return False 2994 return self.get_backend(device).can_fuse_horizontal(node1, node2) 2995 2996 def can_fuse_vertical( 2997 self, node1: BaseSchedulerNode, node2: BaseSchedulerNode 2998 ) -> bool: 2999 """ 3000 Check if it is legal to fuse a consumer (node2) into a producer (node1). 3001 3002 We can fuse them if all the reads of node2 either match 3003 corresponding writes in node1, or are written by nodes that can 3004 be scheduled before the fusion of node1 and node2. 3005 """ 3006 node1_buf_names = node1.get_buffer_names() 3007 node1_op_names = node1.get_operation_names() 3008 computed_deps: OrderedSet[Dep] = OrderedSet() 3009 why = WhyNoFuse(node1, node2) 3010 3011 for cd in node1.read_writes.writes: 3012 if not isinstance(cd, MemoryDep): 3013 continue 3014 for rd in node2.unmet_dependencies: 3015 if self.fusable_read_and_write(rd, cd): 3016 computed_deps.add(rd) 3017 3018 for dep in node2.unmet_dependencies: 3019 if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): 3020 computed_deps.add(dep) 3021 3022 remaining_deps = OrderedSet( 3023 dep.name for dep in node2.unmet_dependencies - computed_deps 3024 ) 3025 if remaining_deps & node1_buf_names: 3026 # MemoryDeps didn't match and read different locations of the same buffer. 3027 # Examples here include: 3028 # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) 3029 # - MemoryDep("foo", x) != StarDep("foo") 3030 why("memory deps did not match") 3031 return False 3032 for name in remaining_deps: 3033 op_name = self.name_to_buf[name].defining_op.get_name() 3034 if node1_op_names & self.name_to_fused_node[op_name].ancestors: 3035 why("intermediate nodes between node1 & node2") 3036 return False 3037 3038 return True 3039 3040 def fusable_weak_dep( 3041 self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode 3042 ) -> bool: 3043 if weak_dep.name not in node1.get_buffer_names(): 3044 return False 3045 3046 # A weak dep can be fused if and only if the fused operation acts inplace 3047 # on the buffer being mutated. i.e. 
the same index is being read then mutated 3048 mutating_writes = [ 3049 write 3050 for write in node2.read_writes.writes 3051 if write.name == weak_dep.mutating_buf 3052 ] 3053 if len(mutating_writes) != 1: 3054 return False 3055 write = mutating_writes[0] 3056 assert isinstance(write, MemoryDep) 3057 3058 if free_symbol_is_type(write.index, SymT.TMP): 3059 return False 3060 3061 real_name = self.mutation_real_name[weak_dep.mutating_buf] 3062 relevant_reads = [ 3063 read for read in node1.read_writes.reads if read.name == real_name 3064 ] 3065 return all( 3066 isinstance(read, MemoryDep) 3067 and not free_symbol_is_type(read.index, SymT.TMP) 3068 and read.index == write.index 3069 and read.size == write.size 3070 for read in relevant_reads 3071 ) 3072 3073 # StarDep doesn't match MemoryDep, different indices don't match 3074 # However, broadcasting sometimes strips dimensions, and if that's the case 3075 # we still can match unmet dep 3076 # if there's indirect indexing, don't match it 3077 def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: 3078 if isinstance(read, MemoryDep): 3079 if read.mode == write.mode and write.mode is not None: 3080 return True 3081 read_name = self.mutation_renames.get(read.name, read.name) 3082 3083 if ( 3084 read_name != write.name 3085 or free_symbol_is_type(read.index, SymT.TMP) 3086 or free_symbol_is_type(write.index, SymT.TMP) 3087 ): 3088 return False 3089 3090 if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: 3091 # Need merge loops if we do loop ordering after fusion since 3092 # we have not merged the loops yet when creating the scheduler 3093 # nodes. 3094 read = read.normalize() 3095 write = write.normalize() 3096 3097 return ( 3098 read.index == write.index 3099 and len(read.size) >= len(write.size) 3100 and read.size[: len(write.size)] == write.size 3101 ) 3102 elif isinstance(read, StarDep): 3103 read_name = self.mutation_renames.get(read.name, read.name) 3104 write_name = self.mutation_renames.get(write.name, write.name) 3105 if ( 3106 read.mode == write.mode 3107 and write.mode is not None 3108 and read_name == write_name 3109 ): 3110 return True 3111 return False 3112 3113 def score_fusion( 3114 self, node1: BaseSchedulerNode, node2: BaseSchedulerNode 3115 ) -> Tuple[bool, bool, int, int]: 3116 """ 3117 Assign a score (higher comes first) to the fusion of node1 3118 and node2. When different fusions conflict with each other, 3119 this is the way we decide what order to run them in. 3120 3121 Our current score is based on: 3122 - Estimate of the saved memory operations 3123 - Fusions closer together in original order 3124 """ 3125 memory_score = self.score_fusion_memory(node1, node2) 3126 proximity_score = -max( 3127 abs(node1.min_order - node2.max_order), 3128 abs(node2.min_order - node1.max_order), 3129 ) 3130 return ( 3131 node1.is_template() == config.epilogue_fusion_first and memory_score > 0, 3132 node1.is_reduction() == node2.is_reduction() and memory_score > 0, 3133 memory_score, 3134 proximity_score, 3135 ) 3136 3137 def dep_size_hint(self, dep: Dep) -> int: 3138 res = 0 3139 if dep not in self.__dep_size_hint_cache: 3140 try: 3141 if not dep.has_unbacked_symbols(): 3142 res = dep.numbytes_hint() 3143 except KeyError: 3144 # In at least one test (test/inductor/test_torchbind.py) we 3145 # create a StarDep that doesn't exist in the graph and calling 3146 # `has_unbacked_symbols()` throws an error. 
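                # We deliberately swallow the error and keep the default size
                # hint of 0 bytes for such deps.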
3147 pass 3148 self.__dep_size_hint_cache[dep] = res 3149 else: 3150 res = self.__dep_size_hint_cache[dep] 3151 return res 3152 3153 def score_fusion_memory( 3154 self, node1: BaseSchedulerNode, node2: BaseSchedulerNode 3155 ) -> int: 3156 """ 3157 The first term in our fusion score that estimates number of saved 3158 memory operations. 3159 """ 3160 node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes) 3161 node2_dep_len = len(node1.read_writes.reads) + len(node2.read_writes.writes) 3162 3163 # optimization: iter over smaller set 3164 if max(node1_dep_len, node2_dep_len) * 4 > min(node1_dep_len, node2_dep_len): 3165 if node1_dep_len > node2_dep_len: 3166 tmp = node1 3167 node1 = node2 3168 node2 = tmp 3169 3170 deps = [] 3171 for dep in node1.read_writes.reads | node1.read_writes.writes: 3172 if dep in node2.read_writes.reads or dep in node2.read_writes.writes: 3173 deps.append(dep) 3174 3175 return sum(self.dep_size_hint(dep) for dep in deps) 3176 3177 common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( 3178 node2.read_writes.reads | node2.read_writes.writes 3179 ) 3180 return sum(self.dep_size_hint(dep) for dep in common_memory_deps) 3181 3182 def get_possible_fusions_with_highest_priority( 3183 self, possible_fusions: List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] 3184 ) -> List[Tuple[BaseSchedulerNode, BaseSchedulerNode]]: 3185 # Group the possible fusions based on their priority from the backend. 3186 # Only return the group of possible fusions with highest priority. 3187 if len(possible_fusions) == 0: 3188 return possible_fusions 3189 possible_fusions_group_by_priority: Dict[ 3190 int, List[Tuple[BaseSchedulerNode, BaseSchedulerNode]] 3191 ] = {} 3192 3193 for node1, node2 in possible_fusions: 3194 assert node1.get_device() == node2.get_device() 3195 device = node1.get_device() 3196 fusion_pair_priority = int( 3197 self.get_backend(device).get_fusion_pair_priority(node1, node2) 3198 ) 3199 if fusion_pair_priority not in possible_fusions_group_by_priority: 3200 possible_fusions_group_by_priority[fusion_pair_priority] = [ 3201 (node1, node2), 3202 ] 3203 else: 3204 possible_fusions_group_by_priority[fusion_pair_priority].append( 3205 (node1, node2) 3206 ) 3207 # return the possible fusions with highest priority 3208 possible_fusions_with_highest_priority = min( 3209 possible_fusions_group_by_priority.items(), key=operator.itemgetter(0) 3210 )[1] 3211 assert len(possible_fusions_with_highest_priority) > 0 3212 return possible_fusions_with_highest_priority 3213 3214 def score_fusion_key( 3215 self, nodes: Tuple[BaseSchedulerNode, BaseSchedulerNode] 3216 ) -> Tuple[bool, bool, int, int]: 3217 """ 3218 Shim for list.sort(key=...) 
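
        (Used as the sort key in get_possible_fusions() with reverse=True, so
        higher-scoring candidate pairs are attempted first.)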
3219 """ 3220 node1, node2 = nodes 3221 return self.score_fusion(node1, node2) 3222 3223 def compute_last_usage(self) -> None: 3224 """ 3225 Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) 3226 """ 3227 3228 future_used_buffers: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) 3229 3230 for node in reversed(self.nodes): 3231 node.set_last_usage(future_used_buffers, self.mutation_real_name) 3232 future_used_buffers.update(node.last_usage) 3233 3234 def free_buffers(self) -> None: 3235 """Free any buffers that are no longer needed""" 3236 for name in sorted( 3237 self.buffer_names_to_free 3238 - V.graph.removed_buffers 3239 - V.graph.wrapper_code.freed 3240 ): 3241 if name in self.name_to_buf: 3242 buf = self.name_to_buf[name] 3243 if buf.can_free(): 3244 V.graph.wrapper_code.codegen_free(buf.node) 3245 elif name in V.graph.graph_inputs: 3246 storage = V.graph.graph_inputs[name].data 3247 assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer() 3248 V.graph.wrapper_code.codegen_free(storage.data) 3249 3250 self.buffer_names_to_free.clear() 3251 3252 def remove_kernel_local_buffers(self) -> None: 3253 """ 3254 Any buffers that are both created and have a last use in the 3255 same kernel can be removed. 3256 """ 3257 3258 fused_node_names = OrderedSet( 3259 self.name_to_buf[buf].defining_op.get_name() 3260 for buf in V.kernel.store_buffer_names 3261 if buf in self.name_to_buf 3262 ) 3263 names_to_remove = [] 3264 for out_buf in V.kernel.store_buffer_names: 3265 if out_buf not in self.name_to_buf: 3266 # Aux buffers created during kernel codegen 3267 names_to_remove.append(out_buf) 3268 continue 3269 users = self.name_to_buf[out_buf].users 3270 assert users is not None 3271 users = OrderedSet(user.get_name() for user in users if not user.is_weak) 3272 if users.issubset(fused_node_names): 3273 names_to_remove.append(out_buf) 3274 3275 def remove_filter(n: str) -> bool: 3276 return ( 3277 n not in V.kernel.must_keep_buffers 3278 and n not in V.kernel.args.input_buffers 3279 and n not in self.mutation_renames 3280 and n not in self.mutation_real_name 3281 ) 3282 3283 names_to_remove = list(filter(remove_filter, names_to_remove)) 3284 3285 for name in names_to_remove: 3286 if name in V.kernel.args.inplace_buffers: 3287 buf = V.kernel.args.inplace_buffers[name] 3288 if isinstance(buf, str) and buf.startswith("REMOVED"): 3289 continue 3290 remove = all(n in names_to_remove for n in buf.other_names) 3291 if remove: 3292 self.remove_inplace_buffer(name) 3293 V.kernel.inplaced_to_remove.add(name) 3294 else: 3295 self.remove_buffer(name) 3296 3297 def remove_buffer(self, name: str) -> None: 3298 # Assign a special value instead of deleting the entry 3299 # because we still rely on output_buffers's length to 3300 # generate unique arg name. 
3301 log.debug("remove_buffer(%r)", name) 3302 V.kernel.args.output_buffers[name] = "REMOVED" 3303 V.kernel.removed_buffers.add(name) 3304 3305 def remove_inplace_buffer(self, name: str) -> None: 3306 log.debug("removing_inplace_buffer(%r)", name) 3307 inner_name = V.kernel.args.inplace_buffers[name].inner_name 3308 V.kernel.args.inplace_buffers[name] = inner_name.replace( 3309 "in_out_ptr", "REMOVED" 3310 ) 3311 V.kernel.removed_buffers.add(name) 3312 3313 def flush(self) -> None: 3314 for backend in self.backends.values(): 3315 backend.flush() 3316 self.free_buffers() 3317 3318 def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None: 3319 assert isinstance(scheduler_node, ExternKernelSchedulerNode) 3320 # 'decide_inplace_update' stores the inplace update decisions in 3321 # the current kernel from where 'allocate' retrieve those decisions. 3322 # We have to make sure there is a non-NULL kernel handler to store 3323 # those inplace update decisions. 3324 counters["inductor"]["extern_calls"] += 1 3325 with V.set_kernel_handler(Kernel(increase_kernel_count=False)): 3326 scheduler_node.decide_inplace_update() 3327 scheduler_node.mark_run() 3328 node = scheduler_node.node 3329 assert isinstance(node, ir.ExternKernel), f"{type(node)=}" 3330 node.codegen(V.graph.wrapper_code) 3331 self.free_buffers() 3332 3333 def create_backend(self, device: torch.device) -> BaseScheduling: 3334 assert ( 3335 not is_gpu(device.type) or device.index is not None 3336 ), f"{device} should have been normalized in lowering" 3337 V.graph.add_device_info(device) 3338 3339 device_scheduling = get_scheduling_for_device(device.type) 3340 if device_scheduling is None: 3341 raise RuntimeError(f"Unsupported device type: {device.type}") 3342 3343 if not has_triton(): 3344 if ( 3345 device.type == "cuda" 3346 and (device_props := torch.cuda.get_device_properties(device)).major < 7 3347 ): 3348 raise RuntimeError( 3349 f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 3350 ) 3351 elif is_gpu(device.type): 3352 raise RuntimeError( 3353 "Cannot find a working triton installation. Either the package is not installed or it is too old. 

        return device_scheduling(self)

    def get_backend(self, device: torch.device) -> BaseScheduling:
        if device not in self.backends:
            self.backends[device] = self.create_backend(device)
        return self.backends[device]

    def enter_context(self, node: BaseSchedulerNode) -> None:
        def get_order(n: torch.fx.Node) -> int:
            if n not in self.origin_to_index:
                self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
            return self.origin_to_index[n]

        # Use a dict to have ordering
        origins = {
            (get_order(e), e): None
            for n in node.get_nodes()
            if n.node is not None
            for e in n.node.get_origins()
        }
        origins = list(origins.keys())
        if origins:
            _, last = max(origins, key=operator.itemgetter(0))
            V.graph.wrapper_code.enter_context(last)

    def codegen(self) -> None:
        with dynamo_timed("Scheduler.codegen"):
            return self._codegen()

    def _codegen(self) -> None:
        if config.check_stack_no_cycles_TESTING_ONLY:
            import torch._dynamo.convert_frame

            stack = traceback.extract_stack()
            seen = set()
            for frame in reversed(stack):
                # This is where maybe_cprofile is
                if (
                    frame.name == "_compile_inner"
                    and frame.filename == torch._dynamo.convert_frame.__file__
                ):
                    break
                key = (frame.filename, frame.lineno)
                assert key not in seen, (
                    f"Duplicate stack frame {frame.filename}:{frame.lineno}; "
                    "did you add a decorator to one of the functions in this stack "
                    "trace? If so, try using a context manager instead."
                )
                seen.add(key)
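
        # Main codegen loop.  For each scheduled node we (1) flush pending
        # kernels when the device changes or when we hit an extern/template
        # node, (2) enter/exit device guards as needed, and (3) dispatch to the
        # matching codegen path (template, extern call, foreach/combo kernel,
        # fused or regular kernel, or nop).
        #
        # Illustrative scenario (hypothetical schedule): for
        #   [pointwise on cuda:0, extern mm on cuda:0, pointwise on cuda:0]
        # the pending fused kernel is flushed before the extern mm, so the
        # generated kernels and the extern call appear in the wrapper code in
        # execution order.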
        for node in self.nodes:
            try:
                log.debug(
                    "Generating code for node %s with estimated runtime %f",
                    node.get_name(),
                    node.get_estimated_runtime(),
                )
            except Exception:
                log.debug(
                    "Generating code for node %s with estimated runtime 0.0",
                    node.get_name(),
                )

            self.enter_context(node)

            if not isinstance(node, NopKernelSchedulerNode) and (
                device := node.get_device()
            ):
                if (
                    device != self.current_device
                    or node.is_extern()
                    or node.is_template()
                ):
                    self.flush()
                if device != self.current_device:
                    if self.current_device and device_need_guard(
                        self.current_device.type
                    ):
                        V.graph.wrapper_code.codegen_device_guard_exit()
                    if device_need_guard(device.type):
                        assert device.index is not None, "device should have an index"
                        V.graph.wrapper_code.codegen_device_guard_enter(device.index)

                    self.current_device = device

            self.buffer_names_to_free.update(node.last_usage)

            if node.is_template():
                node, *epilogue = node.get_nodes()
                self.get_backend(device).codegen_template(node, epilogue)
            elif node.is_extern():
                node = typing.cast(ExternKernelSchedulerNode, node)
                self.codegen_extern_call(node)
            elif node.is_foreach():
                node = typing.cast(ForeachKernelSchedulerNode, node)
                backend_ = self.get_backend(device)
                from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
                from .codegen.simd import SIMDScheduling

                if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
                    backend = backend_
                else:
                    raise AssertionError(f"{type(backend_)=}")
                backend.codegen_combo_kernel(node)
            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
                self.get_backend(device).codegen_node(node)
            else:
                assert isinstance(node, NopKernelSchedulerNode)
                node.mark_run()

            if config.triton.debug_sync_kernel:
                self.get_backend(device).codegen_sync()

            self.available_buffer_names.update(node.get_buffer_names())
            self.completed_operations.update(node.get_operation_names())

            if not isinstance(node, NopKernelSchedulerNode):
                device = node.get_device()
                if device is not None and self.get_backend(device).ready_to_flush():
                    self.flush()

        if self.current_device and device_need_guard(self.current_device.type):
            # exit the outermost CUDA device guard. this is
            # important for nested indentation codegen-ing.
            V.graph.wrapper_code.codegen_device_guard_exit()

        self.flush()

    def benchmark_combo_kernel(
        self, node_list: Sequence[BaseSchedulerNode]
    ) -> Tuple[float, float, str]:
        """
        Benchmark the given list of nodes as a combo kernel and return the
        execution time and memory copy time in milliseconds on randomly
        generated inputs.
        """
        device = node_list[0].get_device()
        V.graph.scheduler = self
        self.current_device = device
        backend = self.get_backend(device)
        return backend.benchmark_combo_kernel(node_list)

    def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool:
        """
        If config.benchmark_combo_kernel is False, always return True.
        Otherwise, return True only if the benchmarked combo kernel brings a
        speedup (or the kernels are too small to benchmark reliably).
        """
        if not config.benchmark_combo_kernel:
            return True

        subkernel_nodes = nodes
        device = subkernel_nodes[0].get_device()

        # benchmarking combo kernels is not supported for CPU right now.
        if device.type == "cpu":
            return True

        from triton.compiler.errors import CompilationError

        ms1, path1_list = 0.0, []
        for i, snode in enumerate(subkernel_nodes):
            node_list = snode.get_nodes()
            # We cannot accurately benchmark kernels that use atomic_add
            # due to how we generate random integer inputs.
            if self._any_atomic_add(node_list):
                fusion_log.debug(
                    "ComboKernel: benchmarking may not be accurate due to atomic_add"
                )

            try:
                ms, path = self.benchmark_fused_nodes(node_list)
                if math.isinf(ms):
                    fusion_log.debug(
                        "ComboKernel benchmark: register spilling of %d-th subkernel",
                        i,
                    )
                    return False
            except CompilationError as e:
                # workaround triton issue: https://github.com/openai/triton/issues/2151
                if "Loop-carried variable" in str(e):
                    fusion_log.debug(
                        "ComboKernel benchmark: return True because of loop-carried variable"
                    )
                    return True  # allow fusion
                else:
                    raise
            ms1 += ms
            path1_list.append(path)

        try:
            ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes)
        except CompilationError as e:
            # workaround triton issue: https://github.com/openai/triton/issues/2151
            if "Loop-carried variable" in str(e):
                fusion_log.debug(
                    "ComboKernel benchmark: return True because of loop-carried variable"
                )
                return True  # allow fusion
            else:
                raise

        # Small kernels are very likely to benefit from the combo kernel but are
        # hard to benchmark reliably, so we allow the fusion instead of trusting
        # the measurement.
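        # Worked example with made-up numbers: if the subkernels measure
        # ms1 = 0.25 ms in total, and the combo kernel measures ms2 = 0.6 ms of
        # which ms2_clone = 0.4 ms is input cloning, then ms2 - ms2_clone = 0.2 ms
        # is below the 0.3 ms threshold, so small_kernel is True and the combo
        # kernel is accepted even though the raw numbers are too noisy to compare.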
        small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3
        if fusion_log.isEnabledFor(logging.DEBUG):
            if ms1 > ms2 or small_kernel:
                fusion_log.debug(
                    "can fuse (benchmark): fusing causes %sx speedup",
                    green_text(f"{ms1 / ms2:.3f}"),
                )
            else:
                fusion_log.debug(
                    "cannot fuse (benchmark): fusing causes %sx slowdown",
                    red_text(f"{ms1 / ms2:.3f}"),
                )
        # ms1 returned by benchmark_fused_nodes already has the clone time discounted
        return ms2 - ms2_clone < ms1 or small_kernel

    def get_buffer_layout(self, buf_name: str) -> ir.Layout:
        buf = self.name_to_buf[buf_name]
        assert buf.node is not None
        return buf.node.get_layout()

    def update_zero_dim_cpu_tensor(self) -> None:
        """
        Collect the names of zero-dim CPU tensors that are read by GPU nodes
        into V.graph.zero_dim_cpu_tensor_list.
        """
        for node in self.nodes:
            if node.get_device() and is_gpu(node.get_device().type):
                for read in node.read_writes.reads:
                    buffer = V.graph.name_to_buffer.get(read.name)
                    if (
                        buffer
                        and buffer.get_device()
                        and buffer.get_device().type == "cpu"
                        and not isinstance(buffer.layout, MultiOutputLayout)
                        and buffer.get_size() == []
                    ):
                        V.graph.zero_dim_cpu_tensor_list.add(read.name)


class BaseScheduling:
    @classmethod
    def get_backend_features(cls, device: torch.device) -> Sequence[BackendFeature]:
        """Return the set of .codegen.common.BackendFeature values supported on the given device"""
        return ()

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Check whether node1 and node2 can be vertically fused or not.
        """
        raise NotImplementedError

    def can_fuse_horizontal(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Check whether node1 and node2 can be horizontally fused or not.
        """
        raise NotImplementedError

    def fuse(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> FusedSchedulerNode:
        """
        Fuse two nodes.
        """
        if node1.is_foreach() or node2.is_foreach():
            return ForeachKernelSchedulerNode.fuse(node1, node2)
        else:
            return FusedSchedulerNode.fuse(node1, node2)

    def group_fn(
        self, sizes: Sequence[Sequence[sympy.Expr]]
    ) -> Tuple[Tuple[sympy.Expr, ...], ...]:
        """
        Process the iteration sizes in case a transformation needs to be applied.
        """
        raise NotImplementedError

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
    ) -> Optional[str]:
        """
        Given a template node, generate a kernel.

        This is currently only available for the Triton backend. A third-party
        backend that subclasses TritonScheduling can reuse or override it.
        """
        raise NotImplementedError

    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
        """
        Generate a kernel for the given node, which may be a FusedSchedulerNode
        wrapping a list of pre-fused nodes.
        """
        raise NotImplementedError

    def codegen_sync(self) -> None:
        """
        Generate synchronization code for the kernel. This method depends on the hardware characteristics.
        """
        raise NotImplementedError

    def ready_to_flush(self) -> bool:
        """
        Check whether the backend wants the scheduler to flush the generated kernel.
        Backends that do not support this should return False.
        """
        return False

    def flush(self) -> None:
        """
        Flush the generated kernel and python wrapper code to the source code file.
        """
        raise NotImplementedError

    def benchmark_fused_nodes(
        self, nodes: Sequence[BaseSchedulerNode]
    ) -> Tuple[float, str]:
        """
        Benchmark fused list of nodes and return the execution time
        in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError

    def get_fusion_pair_priority(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> int:
        """
        Return an unsigned integer which represents the priority of this fusion pair.
        Smaller values indicate higher priority.
        """
        return 0

    def benchmark_combo_kernel(
        self, node_list: Sequence[BaseSchedulerNode]
    ) -> Tuple[float, float, str]:
        """
        Benchmark the list of nodes to combine and return the execution time
        and memory copy time in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError


def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]:
    lines = []
    multi_template = node.get_template_node()
    assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer)
    if multi_template and multi_template.make_kernel_render is None:
        lines.append(f"{node.get_name()} Unfinalized multi template buffer")
    else:
        from torch._inductor.codegen.cuda_combined_scheduling import (
            CUDACombinedScheduling,
        )

        from .codegen.simd import SIMDScheduling

        snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes
        device = snodes[0].get_device()
        backend = node.scheduler.get_backend(device)
        assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling))
        V.graph.scheduler.current_device = device

        # Don't increment the kernel count when generating a debug string;
        # doing so would confuse unit tests that check the number of
        # generated kernels.
        old_generated_kernel_count = metrics.generated_kernel_count
        triton_code = backend.generate_kernel_code_from_nodes(snodes).strip()
        metrics.generated_kernel_count = old_generated_kernel_count

        lines.append(f"{node.get_name()} Triton code:")
        lines.append(textwrap.indent(triton_code, "    "))
    return lines
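

# Illustrative sketch (not used anywhere in this file): an out-of-tree backend
# would typically subclass BaseScheduling and be constructed by
# Scheduler.create_backend via get_scheduling_for_device.  All names below are
# hypothetical; see torch._inductor.codegen.common for the real registration
# hooks.
#
#     class MyDeviceScheduling(BaseScheduling):
#         def __init__(self, scheduler: Scheduler) -> None:
#             self.scheduler = scheduler
#
#         def can_fuse_vertical(self, node1, node2):
#             return False  # start conservative: no fusion
#
#         def can_fuse_horizontal(self, node1, node2):
#             return False
#
#         def group_fn(self, sizes):
#             return tuple(tuple(s) for s in sizes)
#
#         def codegen_node(self, node):
#             ...  # emit a kernel for the (possibly fused) node
#
#         def flush(self):
#             ...  # write any pending kernels into the wrapper code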