# mypy: allow-untyped-defs
# pyre-strict
from __future__ import annotations

import heapq
import operator
import sys
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING

import torch

from . import config, ir
from .dependencies import WeakDep
from .utils import (
    contains_collective,
    contains_wait,
    find_recursive_deps_of_node,
    find_recursive_users_of_node,
    is_collective,
    is_fallback_op,
    is_wait,
)


overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")

if TYPE_CHECKING:
    from .scheduler import BaseSchedulerNode


def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily schedules waits as late as possible.
    """
    return _schedule_for_comm(
        snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
    )


def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
    """
    Greedily schedules comms as early as possible.
    """
    return _schedule_for_comm(
        snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
    )


def reorder_compute_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    """
    This achieves the following overall scheduling procedure:
    Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
        that are required for comm N + 1 but do not depend on comm N, to run at the same time as comm N.
    Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
        Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
        We prioritize compute nodes that are needed sooner.
    Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
    Step 4: We schedule comm N + 1.
    Repeat this for subsequent comm nodes.
    """
    return _schedule_for_comm(
        snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
    )


def _schedule_for_comm(
    snodes: List[BaseSchedulerNode],
    raise_comms: bool,
    sink_waits: bool,
    reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]:
    """
    Schedule `snodes` for various comm optimization objectives.

    Args:
        snodes: the nodes to be scheduled.
        raise_comms: whether to greedily schedule collectives as early as possible.
        sink_waits: whether to greedily schedule waits as late as possible.
        reorder_for_overlap: whether to reorder compute nodes to optimize for
            compute/communication overlap.

    Returns:
        The new schedule order.

    Some notes on the synergy between the different options:
    - `raise_comms` provides more overlapping opportunities for `reorder_for_overlap`.
    - When both `raise_comms` and `sink_waits` are `True`, `raise_comms` is prioritized.
    """
    # We assign each node a tuple of scores (score_0, score_1, score_2),
    # decreasing in importance, with a lower value indicating a higher ranking:
    #
    # - score_0: the lowest comm_idx among the comm nodes that the node blocks.
    #   If a node doesn't block any comm nodes, its score_0 is set to
    #   sys.maxsize. This score ensures that comm nodes get scheduled as early
    #   as possible.
    # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
    #   that wait nodes are deferred as late as possible.
    # - score_2: the index of the node in the original topological order. This
    #   score provides stability in case of ties.
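    #
    # Illustrative example (hypothetical node names): suppose op1 feeds
    # collective op3, op2 is independent compute, and op4 waits on op3. With
    # raise_comms=True and sink_waits=True the score tuples are roughly:
    #   op1 -> (0, 0, 1)            # blocks comm 0
    #   op2 -> (sys.maxsize, 0, 2)  # blocks no comm
    #   op3 -> (0, 0, 3)            # the comm itself
    #   op4 -> (sys.maxsize, 1, 4)  # wait node
    # so, respecting dependencies, the heap pops op1, op3, op2, op4: the comm is
    # raised above the independent compute and its wait is sunk below it.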
    #
    # When only raise_comms is True, only score_0 and score_2 are considered.
    # When only sink_waits is True, only score_1 and score_2 are considered.
    # When neither is True, the original order is yielded.
    buf_name_to_snode = {}
    name_to_fused_node = {}
    scores_0, scores_1, scores_2 = {}, {}, {}
    for idx, snode in enumerate(snodes):
        for buf_name in snode.get_buffer_names():
            buf_name_to_snode[buf_name] = snode

        for op_name in snode.get_operation_names():
            name_to_fused_node[op_name] = snode
        name_to_fused_node[snode.get_name()] = snode

        node_name = snode.get_name()
        scores_0[node_name] = sys.maxsize
        scores_1[node_name] = 0
        scores_2[node_name] = idx

    comm_idx = 0
    for snode in snodes:
        if raise_comms and contains_collective(snode):
            scores_0[snode.get_name()] = comm_idx
            for anc in snode.ancestors:
                anc_fused_name = name_to_fused_node[anc].get_name()
                scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
            comm_idx += 1
        elif sink_waits and contains_wait(snode):
            scores_1[snode.get_name()] = 1

    class Runnable:
        def __init__(self, snode) -> None:
            self.snode = snode
            name = next(iter(snode.get_operation_names()))
            fused_name = name_to_fused_node[name].get_name()
            self.score = (
                scores_0[fused_name],
                scores_1[fused_name],
                scores_2[fused_name],
            )

        def __lt__(self, other):
            return self.score < other.score

    unmet_deps: Dict[BaseSchedulerNode, Set[str]] = {
        snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes
    }

    ready: List[Runnable] = []
    buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set)
    snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}

    for snode, deps in unmet_deps.items():
        if len(deps) == 0:
            heapq.heappush(ready, Runnable(snode))
        for dep in deps:
            buffer_users[dep].add(snode)

    scheduled = []

    def schedule(snode):
        """
        Schedules `snode` and puts all unblocked nodes onto the ready queue.
        """
        scheduled.append(snode)
        for buf_name in snode.get_buffer_names():
            for snode in buffer_users[buf_name]:
                unmet_deps[snode].remove(buf_name)
                if len(unmet_deps[snode]) == 0:
                    heapq.heappush(ready, Runnable(snode))

    def get_overlapping_candidate():
        """
        Return the next node in the ready queue that's neither a collective nor
        a wait.
        """
        candidates = [
            x
            for x in ready
            if not contains_collective(x.snode) and not contains_wait(x.snode)
        ]
        if len(candidates) == 0:
            return None
        return min(candidates, key=lambda x: x.score)

    def schedule_collective_for_overlap(snode):
        """
        Schedules collective node `snode`, along with one or more compute nodes
        to overlap with it. The strategy is described in the comment of
        `reorder_compute_for_overlap`.
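
        Illustrative sketch (hypothetical costs): if the collective is estimated
        at 120 and the best-scored ready compute candidates cost 50, 40, and 80,
        the loop below drains them in score order while the remaining budget is
        still positive (120 -> 70 -> 30 -> -50) and then stops, so all three get
        scheduled to overlap with the collective.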
        """
        assert contains_collective(snode)
        schedule(snode)

        collective_cost = snode_to_cost[snode]
        while (
            collective_cost > 0
            and (candidate := get_overlapping_candidate()) is not None
        ):
            ready.remove(candidate)
            schedule(candidate.snode)
            collective_cost -= snode_to_cost[candidate.snode]
        heapq.heapify(ready)

    while len(ready):
        snode = heapq.heappop(ready).snode
        if reorder_for_overlap and contains_collective(snode):
            schedule_collective_for_overlap(snode)
        else:
            schedule(snode)

    for snode, deps in unmet_deps.items():
        assert len(deps) == 0, (
            "Detected unscheduled nodes. "
            f"Nodes with unmet dependencies: {unmet_deps}"
        )
    return scheduled


def decide_global_ordering_of_comms(
    nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]:
    """
    Decide the global ordering of comms by enforcing the ordering that's in the input
    graph (which might not be the same ordering as in the eager-mode program).
    TODO: Come up with a better approach
    """
    # If FSDP2 is used, we apply FSDP-specific passes.
    if any(
        is_fallback_op(
            x.node,
            {
                torch.ops.fsdp.all_gather_copy_in.default,
                torch.ops.fsdp.chunk_cat.default,
            },
        )
        for x in nodes
    ):
        nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node)

    comm_nodes = [n for n in nodes if contains_collective(n)]

    for i in range(1, len(comm_nodes)):
        # Enforce ordering by making the previous comm a `WeakDep` dependency of the next comm
        mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
        for buf in comm_nodes[i - 1].get_buffer_names():
            comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf))

    return nodes


def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
    """
    Returns the estimated op runtime in nanoseconds (ns).
    """
    if config.estimate_op_runtime == "default":
        runtime = snode.get_estimated_runtime()
    else:
        assert callable(config.estimate_op_runtime)
        runtime = config.estimate_op_runtime(snode)
    return runtime


def node_summary(snode):
    detail = ""
    if isinstance(snode.node, ir.ExternKernelOut):
        detail = f" ({snode.node.python_kernel_name})"
    out_tensor_info = ""
    if (
        hasattr(snode.node, "layout")
        and hasattr(snode.node.layout, "size")
        and hasattr(snode.node.layout, "stride")
    ):
        out_tensor_info = (
            f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})"
        )
    node_name = ""
    if hasattr(snode.node, "name"):
        node_name = snode.node.name
    return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})"


def visualize_overlap(order):
    total_est_runtime: float = 0.0
    cur_comm_node = None
    for snode in order:
        if cur_comm_node is None:
            if contains_collective(snode):
                total_est_runtime += estimate_op_runtime(snode)
                cur_comm_node = snode.node
            elif is_wait(snode.node):
                raise AssertionError(
                    "Wait is not expected when there is no collective running"
                )
            else:  # exposed compute op
                total_est_runtime += estimate_op_runtime(snode)
            overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
        else:  # cur_comm_node is not None
            if contains_collective(snode):
                raise AssertionError(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            elif is_wait(snode.node):  # end of this comm op
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                cur_comm_node = None
            else:  # overlapped compute op
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}"  # noqa: G004
    )


def reorder_compute_and_comm_for_overlap(
    snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
    order = snodes

    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap before reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
        order = p(order)  # type: ignore[operator]
        if torch.distributed.get_rank() == 0:
            overlap_log.debug(
                f"==== Visualize overlap after reordering pass {p} ===="  # noqa: G004
            )
            try:
                visualize_overlap(order)
            except Exception as e:
                overlap_log.debug(str(e))
    return order


def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    try:
        import torch.distributed._composable.fsdp._fsdp_collectives

        assert torch.distributed.is_available()
        # Assert existence of these ops
        assert (
            torch.ops._c10d_functional.all_gather_into_tensor
            and torch.ops._c10d_functional.all_gather_into_tensor_out
        )
    except (ImportError, AttributeError, AssertionError):
        return

    from .pattern_matcher import (
        CallFunction,
        KeywordArg,
        Match,
        PatternMatcherPass,
        register_graph_pattern,
    )

    """
    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    (getitem_1 = all_gather_copy_in[1];)  # optional

    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);

    ->

    all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
    getitem = all_gather_copy_in[0];
    getitem_1 = all_gather_copy_in[1];

    all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
    """

    def remove_unused_getitem(g):
        # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
        node_list = list(g.nodes)
        for n in node_list:
            if (
                n.target == operator.getitem
                and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
                and n.args[1] == 1
            ):
                g.erase_node(n)

    graph_pass = PatternMatcherPass()

    @register_graph_pattern(
        CallFunction(
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            CallFunction(
                operator.getitem,
                CallFunction(
                    torch.ops.fsdp.all_gather_copy_in.default,
                    KeywordArg("all_gather_inputs"),
                    KeywordArg("inp_split_sizes"),
                    KeywordArg("all_gather_input_numel"),
                    KeywordArg("world_size"),
                    KeywordArg("rank"),
                    KeywordArg("dtype"),
                    KeywordArg("device"),
                ),
                KeywordArg("item_idx"),
            ),
            KeywordArg("group_size"),
            KeywordArg("group_name"),
        ),
        pass_dict=graph_pass,
        extra_check=lambda match: match.kwargs["item_idx"] == 0,
    )
    def reinplace_all_gather(match: Match, *args, **kwargs):
        def repl(
            *args,
        ):
            copy_in_args = args[:-2]
            group_size = args[-2]
            group_name = args[-1]
            all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
                *copy_in_args
            )
            getitem = all_gather_copy_in[0]
            getitem_1 = all_gather_copy_in[1]
            all_gather_into_tensor = (
                torch.ops._c10d_functional.all_gather_into_tensor_out.default(
                    getitem, group_size, group_name, out=getitem_1
                )
            )
            return all_gather_into_tensor

        match.replace_by_example(
            repl,
            [
                kwargs["all_gather_inputs"],
                kwargs["inp_split_sizes"],
                kwargs["all_gather_input_numel"],
                kwargs["world_size"],
                kwargs["rank"],
                kwargs["dtype"],
                kwargs["device"],
                kwargs["group_size"],
                kwargs["group_name"],
            ],
        )

    remove_unused_getitem(graph)
    graph_pass.apply(graph)  # type: ignore[arg-type]


def get_op_idx(snode):
    assert not isinstance(
        snode,
        (
            torch._inductor.scheduler.FusedSchedulerNode,
            torch._inductor.scheduler.GroupedSchedulerNode,
        ),
    )
    return int(snode.get_name()[2:])


def enforce_comm_ordering_for_fsdp(
    snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: Dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
    from . import scheduler
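
    # Rough schematic of the grouping built below (illustrative; the op names
    # mirror the comments in the loop rather than generated kernel names):
    #   AllGather:     [cast + copy_in + getitem + all_gather]    -> one GroupedSchedulerNode
    #                  [all_gather_wait_tensor + copy_out + set_] -> one GroupedSchedulerNode
    #   ReduceScatter: [copy_in + reduce_scatter comm]            -> one GroupedSchedulerNode
    #                  [reduce_scatter wait + output nodes]       -> one GroupedSchedulerNode
    # WeakDeps are then added so that each wait group is ordered before the next
    # comm group of the same kind.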

    new_order: list[BaseSchedulerNode] = []
    scheduled = set()
    ag_exists = False
    rs_exists = False
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        for snode in snodes_to_group:
            snode_name_to_final_snode[snode.get_name()] = group_node
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_snode = snode
            ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block
            allowed_ops = {
                torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.fsdp.split_with_sizes_copy.default,
                torch.ops.aten.set_.source_Tensor,
            }
            find_recursive_users_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(
                ag_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
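            # (Illustrative example, hypothetical indices: if the sorted list holds a
            # second `split_with_sizes_copy` copy-out at index i, then index i and
            # everything after it belong to the next all-gather block, and the loop
            # below truncates the list just before that point.)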
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for i in range(len(ag_related_snodes)):
                cur_snode = ag_related_snodes[i]
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = i
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(ag_related_snodes) - 1):
                if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

            # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode
            ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True
            rs_snode = snode

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set()
            find_recursive_users_of_node(
                rs_snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(
                rs_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(rs_related_snodes) - 1):
                if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])

            # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
            rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])

            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    for snode in snodes:
        if snode.get_name() in snode_name_to_final_snode:
            snode = snode_name_to_final_snode[snode.get_name()]
        if snode in scheduled:
            continue
        new_order.append(snode)
        scheduled.add(snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    prev_ag_wait = None
    for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
        if prev_ag_wait is not None:
            mutating_buf = next(iter(ag_group_node.get_buffer_names()))
            for o in prev_ag_wait.get_outputs():
                ag_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_ag_wait = wait_group_node

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    prev_rs_wait = None
    for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
        if prev_rs_wait is not None:
            mutating_buf = next(iter(rs_group_node.get_buffer_names()))
            for o in prev_rs_wait.get_outputs():
                rs_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf)
                )
        prev_rs_wait = wait_group_node

    return new_order  # type: ignore[return-value]
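

# Example wiring (a minimal sketch, not part of this module's API): the passes in
# this file are looked up by name from `config.reorder_for_compute_comm_overlap_passes`
# inside `reorder_compute_and_comm_for_overlap`. Assuming the
# `reorder_for_compute_comm_overlap` knob in `torch._inductor.config` gates the pass
# pipeline, a typical configuration could look like:
#
#     import torch
#     import torch._inductor.config as inductor_config
#
#     inductor_config.reorder_for_compute_comm_overlap = True
#     inductor_config.reorder_for_compute_comm_overlap_passes = [
#         "sink_waits",
#         "raise_comms",
#         "reorder_compute_for_overlap",
#     ]
#
#     compiled_model = torch.compile(model)  # `model` is a placeholder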