# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, cast, Dict, List, Optional, Set

import torch

from .. import config, inductor_prims
from ..pattern_matcher import (
    CallFunction,
    Ignored,
    KeywordArg,
    ListOf,
    Match,
    MULTIPLE,
    PatternExpr,
    PatternMatcherPass,
)


aten = torch.ops.aten
patterns = PatternMatcherPass()


def _is_backward(graph: torch.fx.Graph) -> bool:
    placeholders = []
    for node in graph.nodes:
        if node.op != "placeholder":
            break
        placeholders.append(node)
    return not all(node.name.startswith("primal") for node in placeholders)


def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float:
    return M * N * K / (M * K + N * K + M * N)


def _filter_nodes_by_target(nodes: List[torch.fx.Node], target) -> List[torch.fx.Node]:
    return [x for x in nodes if x.target == target]


def _find_ancestors(node: torch.fx.Node) -> Set[torch.fx.Node]:
    ancestors = set()
    ancestors.add(node)
    cur_nodes = [node]
    while len(cur_nodes) > 0:
        new_nodes = []
        for node in cur_nodes:
            for inp in node.all_input_nodes:
                if inp not in ancestors:
                    ancestors.add(inp)
                    new_nodes.append(inp)
        cur_nodes = new_nodes
    return {node for node in ancestors if node.op != "placeholder"}


def _get_tensor(node: torch.fx.Node) -> torch.Tensor:
    val = node.meta["val"]
    assert isinstance(val, torch.Tensor)
    return val


@dataclass
class _AllGatherMatch:
    match: Match
    shard_node: torch.fx.Node
    ag_node: torch.fx.Node
    res_node: torch.fx.Node
    gather_dim: int
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)


def find_all_gather_patterns(graph: torch.fx.Graph):
    c10d = torch.ops._c10d_functional

    def make_zero_dim_all_gather_pattern(shard):
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.all_gather_into_tensor.default,
                shard,
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.all_gather_tensor with gather_dim == 0
    zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard"))

    def make_all_gather_split_pattern(shard):
        return CallFunction(
            operator.getitem,
            CallFunction(
                aten.split.Tensor,
                make_zero_dim_all_gather_pattern(shard),
                Ignored(),
                _users=MULTIPLE,
            ),
            Ignored(),
        )

    def make_cat_pattern(splits):
        return CallFunction(
            aten.cat.default,
            ListOf(splits),
            KeywordArg("gather_dim"),
        )

    # Matches funcol.all_gather_tensor with gather_dim > 0
    non_zero_dim_all_gather_pattern = make_cat_pattern(
        make_all_gather_split_pattern(KeywordArg("shard")),
    )

    # Match a zero-dim all-gather in which the data is transferred as uint8 and
    # viewed back as the original dtype.
    zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_zero_dim_all_gather_pattern(
            KeywordArg("shard"),
        ),
        Ignored(),
    )

    # Match a non-zero-dim all-gather in which the data is transferred as uint8
    # and viewed back as the original dtype.
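    # Roughly, the matched chain is (one getitem/view pair per split):
    #   all_gather_into_tensor -> wait_tensor -> split -> getitem -> view.dtype
    #   -> cat -> view.dtype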
    non_zero_dim_type_erased_all_gather_pattern = CallFunction(
        aten.view.dtype,
        make_cat_pattern(
            CallFunction(
                aten.view.dtype,
                make_all_gather_split_pattern(
                    KeywordArg("shard"),
                ),
                Ignored(),
            ),
        ),
        Ignored(),
    )

    # If two patterns with the same res_node_target have the same suffix, the
    # longer pattern should appear first in the list.
    # e.g. suppose we have (1) A -> B -> C -> D and (2) B -> C -> D; then (1)
    # should appear before (2) in the list.
    res_node_target_to_patterns = {
        aten.cat.default: [
            (non_zero_dim_all_gather_pattern, 0),
        ],
        aten.view.dtype: [
            (non_zero_dim_type_erased_all_gather_pattern, 0),
            (zero_dim_type_erased_all_gather_pattern, 0),
        ],
        c10d.wait_tensor.default: [
            (zero_dim_all_gather_pattern, 0),
        ],
    }

    # Match in reverse to ensure longer patterns are prioritized
    all_gathers = []
    visited_ag_nodes = set()
    for node in reversed(graph.nodes):
        for target, patterns in res_node_target_to_patterns.items():
            if node.target != target:
                continue
            for pattern, ag_node_idx in patterns:
                match = pattern.match(node)
                if not match:
                    continue

                assert isinstance(match, Match)
                ag_node = match.nodes[ag_node_idx]
                assert ag_node.target == c10d.all_gather_into_tensor.default

                if ag_node in visited_ag_nodes:
                    continue
                visited_ag_nodes.add(ag_node)

                ag_match = _AllGatherMatch(
                    match=match,
                    shard_node=match.kwargs["shard"],
                    ag_node=ag_node,
                    res_node=node,
                    gather_dim=match.kwargs.get("gather_dim", 0),
                    group_name=match.kwargs["group_name"],
                )
                all_gathers.append(ag_match)

    return list(reversed(all_gathers))


@dataclass
class _ReduceScatterMatch:
    match: Match
    input_node: torch.fx.Node
    rs_node: torch.fx.Node
    res_node: torch.fx.Node
    reduce_op: str
    scatter_dim: int
    group_name: str

    def replace_with(self, new_node: torch.fx.Node) -> None:
        self.res_node.replace_all_uses_with(new_node)

    def erase(self) -> None:
        for node in reversed(self.match.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)


def find_reduce_scatter_patterns(graph: torch.fx.Graph):
    c10d = torch.ops._c10d_functional

    def reduce_scatter_template(inp: PatternExpr):
        return CallFunction(
            c10d.wait_tensor.default,
            CallFunction(
                c10d.reduce_scatter_tensor.default,
                inp,
                KeywordArg("reduce_op"),
                Ignored(),
                KeywordArg("group_name"),
            ),
        )

    # Matches funcol.reduce_scatter_tensor with scatter_dim == 0
    zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))

    # Matches funcol.reduce_scatter_tensor with scatter_dim > 0
    non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
        CallFunction(
            aten.cat.default,
            ListOf(
                CallFunction(
                    operator.getitem,
                    CallFunction(
                        aten.split.Tensor,
                        KeywordArg("input"),
                        Ignored(),
                        KeywordArg("scatter_dim"),
                        _users=MULTIPLE,
                    ),
                    Ignored(),
                )
            ),
        ),
    )

    reduce_scatters = []
    for node in reversed(graph.nodes):
        if node.target == c10d.wait_tensor.default:
            if match := non_zero_dim_reduce_scatter_pattern.match(node):
                assert isinstance(match, Match)
                reduce_scatters.append(
                    _ReduceScatterMatch(
                        match=match,
                        input_node=match.kwargs["input"],
                        rs_node=match.nodes[-2],
                        res_node=node,
                        reduce_op=match.kwargs["reduce_op"],
                        scatter_dim=match.kwargs["scatter_dim"],
                        group_name=match.kwargs["group_name"],
                    )
                )
            elif match := zero_dim_reduce_scatter_pattern.match(node):
                assert isinstance(match, Match)
                reduce_scatters.append(
                    _ReduceScatterMatch(
                        match=match,
                        input_node=match.kwargs["input"],
                        rs_node=match.nodes[0],
                        res_node=node,
                        reduce_op=match.kwargs["reduce_op"],
                        scatter_dim=0,
                        group_name=match.kwargs["group_name"],
                    )
                )
    return list(reversed(reduce_scatters))


@dataclass
class _Matmul:
    nodes: List[torch.fx.Node]
    arg_ancestor_nodes: Set[torch.fx.Node] = field(init=False)
    A_node: torch.fx.Node
    B_node: torch.fx.Node

    def __post_init__(self):
        assert len(self.nodes) in (1, 3)
        if len(self.nodes) == 1:
            assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default)
        else:
            assert self.nodes[0].target == aten.reshape.default
            assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default)
            assert self.nodes[2].target == aten.reshape.default
        self.arg_ancestor_nodes = _find_ancestors(self.B_node)

    def replace_with(self, new_node: torch.fx.Node) -> None:
        """
        Replace the matmul with the new node.
        """
        graph = new_node.graph

        # For 2D-matmuls, we simply replace the mm node with `new_node`.
        if len(self.nodes) == 1:
            mm_node = self.nodes[0]
            assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
            mm_node.replace_all_uses_with(new_node)
            graph.erase_node(mm_node)
            return

        # An ND-matmul is a reshape -> mm -> reshape sequence. We first replace
        # the second reshape node with `new_node`. Then, we ensure that the
        # original mm node in the sequence ends up with zero users by replacing
        # it with a reverse reshape of `new_node`.
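        # (Leaving the mm node with zero users allows erase() to remove it later.)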
        graph = new_node.graph
        assert len(self.nodes) == 3
        mm_node = self.nodes[1]
        output_reshape_node = self.nodes[2]

        assert mm_node.target in (aten.mm.default, aten._scaled_mm.default)
        assert output_reshape_node.target == aten.reshape.default

        output_reshape_node.replace_all_uses_with(new_node)
        if len(mm_node.users) > 1:
            with graph.inserting_after(new_node):
                new_mm_node = graph.call_function(
                    aten.reshape.default,
                    args=(new_node, list(_get_tensor(mm_node).shape)),
                )
                mm_node.replace_all_uses_with(new_mm_node)

    def erase(self) -> None:
        for node in reversed(self.nodes):
            if len(node.users) == 0:
                node.graph.erase_node(node)

    @classmethod
    def from_match(cls, match: List[torch.fx.Node]) -> "_Matmul":
        assert len(match) in (1, 3)
        assert match[0].target in (
            aten.mm.default,
            aten.reshape.default,
        )
        mm_node = match[0] if len(match) == 1 else match[1]
        return _Matmul(
            nodes=match,
            A_node=cast(torch.fx.Node, match[0].args[0]),
            B_node=cast(torch.fx.Node, mm_node.args[1]),
        )


@dataclass
class _ScaledMatmul(_Matmul):
    A_scale_node: torch.fx.Node
    B_scale_node: torch.fx.Node
    bias_node: Optional[torch.fx.Node]
    result_scale_node: Optional[torch.fx.Node]
    out_dtype: Optional[torch.dtype]
    use_fast_accum: bool

    def __post_init__(self):
        super().__post_init__()
        self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node)
        self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node)

    @classmethod
    def from_match(cls, match: List[torch.fx.Node]) -> "_ScaledMatmul":
        assert len(match) in (1, 3)
        assert match[0].target in (
            aten._scaled_mm.default,
            aten.reshape.default,
        )
        mm_node = match[0] if len(match) == 1 else match[1]

        def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
            if idx >= len(node.args):
                return default
            return node.args[idx]

        return _ScaledMatmul(
            nodes=match,
            A_node=cast(torch.fx.Node, match[0].args[0]),
            B_node=cast(torch.fx.Node, mm_node.args[1]),
            A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
            B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
            bias_node=get_arg(mm_node, 4, None),
            result_scale_node=get_arg(mm_node, 5, None),
            out_dtype=get_arg(mm_node, 6, None),
            use_fast_accum=get_arg(mm_node, 7, False),
        )


def _find_reshape_mm_reshape(node: torch.fx.Node) -> List[_Matmul]:
    if node.target != aten.reshape.default:
        return []

    matches = []
    for mm_node in node.users:
        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
            continue
        for reshape_node in mm_node.users:
            if reshape_node.target != aten.reshape.default:
                continue

            # Since the reshape -> mm -> reshape pattern would be subsumed into
            # the fused op, we only match the patterns where the shape of the
            # second reshape matches the mm result produced by the fused op.
            matmul_input_node = cast(torch.fx.Node, node.args[0])
            B_node = cast(torch.fx.Node, mm_node.args[1])
            matmul_out_shape = torch.Size(
                [
                    *_get_tensor(matmul_input_node).shape[:-1],
                    _get_tensor(B_node).shape[-1],
                ]
            )
            if _get_tensor(reshape_node).shape != matmul_out_shape:
                continue
            matches.append([node, mm_node, reshape_node])
            # If, for some rare reason, mm_node is reshaped by two different
            # reshape nodes, we only include mm_node once in the parsing
            # result.
            break

    matmuls = []
    for match in matches:
        mm_node = match[1]
        if mm_node.target == aten.mm.default:
            matmul = _Matmul.from_match(match)
            matmuls.append(matmul)
        elif mm_node.target == aten._scaled_mm.default:
            matmul = _ScaledMatmul.from_match(match)
            matmuls.append(matmul)
        else:
            raise AssertionError(
                "Expected the node's target to be either aten.mm.default or "
                f"aten._scaled_mm.default. Got {mm_node.target}."
            )
    return matmuls


def _find_consumer_matmuls(node: torch.fx.Node) -> List[_Matmul]:
    """
    Find the matmuls that use `node` as the lhs argument.
    """
    matmuls = []
    for user in node.users:
        # ND matmuls
        if user.target == aten.reshape.default:
            matmuls.extend(_find_reshape_mm_reshape(user))
        # 2D matmuls
        elif user.target == aten.mm.default:
            matmul = _Matmul.from_match(match=[user])
            matmuls.append(matmul)
        elif user.target == aten._scaled_mm.default:
            matmul = _ScaledMatmul.from_match([user])
            matmuls.append(matmul)
    return matmuls


def _insert_fused_all_gather_matmul(
    graph: torch.fx.Graph,
    matmuls: List[_Matmul],
    shard_node: torch.fx.Node,
    gather_dim: int,
    group_name: str,
) -> torch.fx.Node:
    mm_types = set(map(type, matmuls))
    assert len(mm_types) == 1
    mm_type = next(iter(mm_types))
    if mm_type == _Matmul:
        B_nodes = [matmul.B_node for matmul in matmuls]
        return graph.call_function(
            torch.ops.symm_mem.fused_all_gather_matmul.default,
            args=(shard_node, B_nodes, gather_dim, group_name),
        )
    elif mm_type == _ScaledMatmul:
        scaled_matmuls = cast(List[_ScaledMatmul], matmuls)
        return graph.call_function(
            torch.ops.symm_mem.fused_all_gather_scaled_matmul.default,
            args=(
                shard_node,
                [matmul.B_node for matmul in scaled_matmuls],
                scaled_matmuls[0].A_scale_node,
                [matmul.B_scale_node for matmul in scaled_matmuls],
                gather_dim,
                group_name,
                [matmul.bias_node for matmul in scaled_matmuls],
                [matmul.result_scale_node for matmul in scaled_matmuls],
                [matmul.out_dtype for matmul in scaled_matmuls],
                [matmul.use_fast_accum for matmul in scaled_matmuls],
            ),
        )
    else:
        raise AssertionError(f"Unexpected matmul match type: {mm_type}")


def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
    """
    Fuse the pattern

        A = all_gather_tensor(A_shard, gather_dim, group_name)
        C_0 = torch.matmul(A, B_0)
        C_1 = torch.matmul(A, B_1)
        C_2 = torch.matmul(A, B_2)
        ...

    into

        A, Cs = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name,
        )
    """
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_nccl_available()
    ):
        return

    c10d = torch.ops._c10d_functional
    from torch.distributed._symmetric_memory import (
        is_symm_mem_enabled_for_group,
        restride_A_shard_for_fused_all_gather_matmul,
    )

    shard_node, ag_node, ag_res_node, gather_dim, group_name = (
        all_gather.shard_node,
        all_gather.ag_node,
        all_gather.res_node,
        all_gather.gather_dim,
        all_gather.group_name,
    )

    if not is_symm_mem_enabled_for_group(group_name):
        return

    if gather_dim >= len(_get_tensor(shard_node).shape) - 1:
        # Decomposing the matmul on the K dimension is not supported
        return

    # Find consumer matmuls
    matmuls = _find_consumer_matmuls(ag_res_node)

    # The matmuls are only fusible if non-A args don't depend on the all-gather
    # result node
    matmuls = [
        matmul
        for matmul in matmuls
        if all_gather.res_node not in matmul.arg_ancestor_nodes
    ]

    if len(matmuls) == 0 or len(set(map(type, matmuls))) != 1:
        return

    # Fuse the all_gather_tensor with the eligible matmuls
    graph = ag_node.graph
    with graph.inserting_before(ag_node):
        if "val" in shard_node.meta:
            restrided = restride_A_shard_for_fused_all_gather_matmul(
                _get_tensor(shard_node),
                gather_dim,
            )
            shard_node = graph.call_function(
                inductor_prims.force_stride_order,
                args=(shard_node, restrided.stride()),
            )

        fused_node = _insert_fused_all_gather_matmul(
            graph, matmuls, shard_node, gather_dim, group_name
        )
        new_ag_node = graph.call_function(
            operator.getitem,
            args=(fused_node, 0),
        )
        new_out_nodes = graph.call_function(
            operator.getitem,
            args=(fused_node, 1),
        )
        for idx, matmul in enumerate(matmuls):
            new_out_node = graph.call_function(
                operator.getitem,
                args=(new_out_nodes, idx),
            )
            matmul.replace_with(new_out_node)
            matmul.erase()
        all_gather.replace_with(new_ag_node)
        all_gather.erase()

    # Raise ancestors of non-A args that are topologically ordered between
    # ag_res_node and the matmul above fused_node.
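    # Otherwise the fused node would depend on values that are only defined
    # later in the graph order.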
    order = {node: idx for idx, node in enumerate(graph.nodes)}
    nodes_to_raise = sorted(
        {x for matmul in matmuls for x in matmul.arg_ancestor_nodes},
        key=lambda x: order[x],
    )
    for node in nodes_to_raise:
        if order[node] > order[fused_node]:
            fused_node.prepend(node)


def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]:
    if node.target == aten.mm.default:
        return _Matmul.from_match(match=[node])
    elif node.target == aten._scaled_mm.default:
        return _ScaledMatmul.from_match(match=[node])
    elif node.target == aten.reshape.default:
        reshape_node_1 = node

        mm_node = reshape_node_1.args[0]
        assert isinstance(mm_node, torch.fx.Node)
        if mm_node.target not in (aten.mm.default, aten._scaled_mm.default):
            return None

        reshape_node_0 = mm_node.args[0]
        assert isinstance(reshape_node_0, torch.fx.Node)
        if reshape_node_0.target != aten.reshape.default:
            return None

        if mm_node.target == aten.mm.default:
            return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1])
        elif mm_node.target == aten._scaled_mm.default:
            return _ScaledMatmul.from_match(
                match=[reshape_node_0, mm_node, reshape_node_1]
            )
    return None


def _insert_fused_matmul_reduce_scatter(
    graph: torch.fx.Graph,
    matmul: _Matmul,
    reduce_op: str,
    scatter_dim: int,
    group_name: str,
) -> torch.fx.Node:
    if type(matmul) == _Matmul:
        return graph.call_function(
            torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
            args=(matmul.A_node, matmul.B_node, reduce_op, scatter_dim, group_name),
        )
    elif type(matmul) == _ScaledMatmul:
        return graph.call_function(
            torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
            args=(
                matmul.A_node,
                matmul.B_node,
                matmul.A_scale_node,
                matmul.B_scale_node,
                reduce_op,
                scatter_dim,
                group_name,
                matmul.bias_node,
                matmul.result_scale_node,
                matmul.out_dtype,
                matmul.use_fast_accum,
            ),
        )
    else:
        raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")


def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
    """
    Fuse the pattern

        reduce_scatter_tensor(A @ B, scatter_dim, group_name)

    into

        torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, scatter_dim, group_name,
        )
    """
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_nccl_available()
    ):
        return

    c10d = torch.ops._c10d_functional
    from torch.distributed._symmetric_memory import (
        is_symm_mem_enabled_for_group,
        restride_A_for_fused_matmul_reduce_scatter,
    )

    input_node, rs_node, rs_res_node, reduce_op, scatter_dim, group_name = (
        reduce_scatter.input_node,
        reduce_scatter.rs_node,
        reduce_scatter.res_node,
        reduce_scatter.reduce_op,
        reduce_scatter.scatter_dim,
        reduce_scatter.group_name,
    )

    if not is_symm_mem_enabled_for_group(group_name):
        return

    # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
    # so we can't apply the fusion if the matmul result is used by multiple
    # users. This is not a fundamental limitation of the fused op and can be
    # addressed if needed.
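    # (`input_node` is the matmul output feeding the reduce-scatter, so a
    # single user means the reduce-scatter is its only consumer.)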
    if len(input_node.users) != 1:
        return

    matmul = _find_producer_matmul(input_node)
    if matmul is None:
        return

    if rs_res_node in matmul.arg_ancestor_nodes:
        return

    graph = rs_res_node.graph
    with graph.inserting_before(rs_res_node):
        if "val" in matmul.A_node.meta:
            restrided = restride_A_for_fused_matmul_reduce_scatter(
                _get_tensor(matmul.A_node),
                scatter_dim,
            )
            matmul.A_node = graph.call_function(
                inductor_prims.force_stride_order,
                args=(matmul.A_node, restrided.stride()),
            )

        fused_node = _insert_fused_matmul_reduce_scatter(
            graph,
            matmul,
            reduce_op,
            scatter_dim,
            group_name,
        )
        reduce_scatter.replace_with(fused_node)
        reduce_scatter.erase()
        matmul.erase()

    order = {node: idx for idx, node in enumerate(graph.nodes)}
    nodes_to_raise = sorted(
        matmul.arg_ancestor_nodes,
        key=lambda x: order[x],
    )
    for node in nodes_to_raise:
        if order[node] > order[fused_node]:
            fused_node.prepend(node)


def _get_node_to_ancestors(
    graph: torch.fx.Graph,
) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
    """
    Compute the ancestors for all nodes in a graph.
    """
    node_to_ancestors = defaultdict(set)
    for node in graph.nodes:
        node_to_ancestors[node] = set(node.all_input_nodes)
        for dep in node.all_input_nodes:
            node_to_ancestors[node] |= node_to_ancestors[dep]

    return node_to_ancestors


def _get_collective_to_overlappable_nodes(
    graph: torch.fx.Graph,
) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
    """
    For each collective in the graph, find nodes that are neither ancestors nor
    descendants of the collective.
    """

    def is_collective(node) -> bool:
        # Only consider all-gather and reduce-scatter in the context of
        # micro-pipeline TP.
        return node.target in [
            torch.ops._c10d_functional.all_gather_into_tensor.default,
            torch.ops._c10d_functional.reduce_scatter_tensor.default,
        ]

    node_to_ancestors = _get_node_to_ancestors(graph)
    collective_to_overlappable_nodes = defaultdict(list)
    for node in graph.nodes:
        if not is_collective(node):
            continue
        for x in graph.nodes:
            if (
                node not in node_to_ancestors[x]
                and x not in node_to_ancestors[node]
                and x.op == "call_function"
            ):
                collective_to_overlappable_nodes[node].append(x)

    return collective_to_overlappable_nodes


def _get_unexposed_collectives(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Find all unexposed collectives in the graph.

    Because we don't have a runtime estimate, this function is a rough
    estimation based on the following strong/hand-wavy assumptions:

    - Only a predefined set of "compute intensive" operations can hide a collective.
    - Any "compute intensive" operation can hide exactly one collective.
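
    For example, if the only compute-intensive node that two collectives can
    overlap with is the same mm node, only one of them is treated as unexposed.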
807 """ 808 809 def _is_compute_intensive(node: torch.fx.Node) -> bool: 810 return node.target in [torch.ops.aten.mm.default] 811 812 collective_to_overlapping_candidates = defaultdict(list) 813 available_nodes = set() 814 collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph) 815 for collective, overlappable_nodes in collective_to_overlappable_nodes.items(): 816 candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)] 817 collective_to_overlapping_candidates[collective] = candidates 818 available_nodes |= set(candidates) 819 820 unexposed_collectives = [] 821 for ( 822 collective, 823 overlapping_candidates, 824 ) in collective_to_overlapping_candidates.items(): 825 # Each collective consumes exactly one overlapping candidate 826 for x in overlapping_candidates: 827 if x in available_nodes: 828 unexposed_collectives.append(collective) 829 available_nodes.remove(x) 830 break 831 return unexposed_collectives 832 833 834def micro_pipeline_tp_pass(graph: torch.fx.Graph): 835 all_gathers = find_all_gather_patterns(graph) 836 reduce_scatters = find_reduce_scatter_patterns(graph) 837 838 # When a collective can be hidden through either simple overlapping or 839 # micro-pipeline TP, we prefer simple overlapping to avoid the overhead 840 # associated with decomposition. If reorder_for_compute_comm_overlap is 841 # enabled, we identify collectives that can be hidden through simple 842 # overlapping and exclude them from micro-pipeline TP candidates. 843 if config.reorder_for_compute_comm_overlap: 844 unexposed_collectives = _get_unexposed_collectives(graph) 845 all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] 846 reduce_scatters = [ 847 x for x in reduce_scatters if x.rs_node not in unexposed_collectives 848 ] 849 850 for all_gather in all_gathers: 851 fuse_all_gather_matmul(all_gather) 852 853 for reduce_scatter in reduce_scatters: 854 fuse_matmul_reduce_scatter(reduce_scatter) 855