1# mypy: allow-untyped-defs 2import itertools 3import logging 4import operator 5from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union 6from typing_extensions import TypeAlias 7 8import torch 9from torch._dynamo.utils import counters 10 11from ..pattern_matcher import ( 12 Arg, 13 CallFunction, 14 CallFunctionVarArgs, 15 CallMethodVarArgs, 16 FailedMatch, 17 get_arg_value, 18 Ignored, 19 KeywordArg, 20 ListOf, 21 Match, 22 MatchContext, 23 MULTIPLE, 24 PatternExpr, 25 PatternMatcherPass, 26 register_graph_pattern, 27 RepeatedExpr, 28) 29from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS 30 31 32log = logging.getLogger(__name__) 33 34_Arguments: TypeAlias = Tuple[torch.fx.node.Argument, ...] 35_TransformParam: TypeAlias = Tuple[ 36 Optional[_Arguments], 37 Optional[_Arguments], 38 Optional[_Arguments], 39 Optional[_Arguments], 40] 41_Range: TypeAlias = Tuple[int, int] 42 43 44PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {} 45POST_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {} 46 47pre_grad_pass_names = [ 48 "normalization_pass", 49 "remove_split_with_size_one_pass", 50 "merge_getitem_cat_pass", 51 "merge_stack_tahn_unbind_pass", 52 "merge_splits_pass", 53 "mutate_cat_pass", 54 "split_cat_pass", 55 "unbind_stack_pass", 56 "split_cat_to_slices_pass", 57 "unbind_cat_to_view_pass", 58 "split_stack_to_cats_pass", 59 "unbind_stack_to_slices_pass", 60 "move_reshape_out_of_split_stack_pass", 61] 62 63post_grad_pass_names = [ 64 "normalization_aten_pass", 65 "decompose_mm_pass", 66 "unbind_stack_aten_pass", 67 "shape_padding_multiplier", 68] 69 70for pass_name in pre_grad_pass_names: 71 # exclude all passes from the group batch fusion 72 # they do not use pattern matcher 73 if pass_name in PRE_GRAD_FUSIONS: 74 continue 75 PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass( 76 pass_name=pass_name, 77 ) 78 79for pass_name in post_grad_pass_names: 80 # exclude all passes from the group batch fusion 81 # they do not use pattern matcher 82 if pass_name in POST_GRAD_FUSIONS: 83 continue 84 POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass( 85 pass_name=pass_name, 86 ) 87 88 89def construct_pattern_matcher_pass(pass_name: str): 90 """ 91 Return the specific pattern_matcher_pass given the pass name. 92 """ 93 if pass_name in PRE_GRAD_PATTERNS: 94 return PRE_GRAD_PATTERNS[pass_name] 95 else: 96 return POST_GRAD_PATTERNS[pass_name] 97 98 99def _get_split_args_default(split_node): 100 input_kwarg = "tensor" 101 split_size_kwarg = "split_size_or_sections" 102 dim_kwarg = "dim" 103 default_dim_value = 0 104 if split_node.op == "call_method": 105 split_size_kwarg = "split_size" 106 return ( 107 get_arg_value(split_node, 0, input_kwarg), 108 get_arg_value(split_node, 1, split_size_kwarg), 109 get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, 110 ) 111 112 113def _get_dim(node: Any): 114 assert isinstance(node, torch.fx.Node) 115 if "dim" in node.kwargs: 116 assert isinstance(node.kwargs["dim"], int) 117 return node.kwargs["dim"] 118 if node.target == torch.unbind: 119 if len(node.args) == 2: 120 assert isinstance(node.args[-1], int) 121 return node.args[-1] 122 return 0 # defaults to dim=0 123 if node.target == torch.split: 124 if len(node.args) == 3: 125 assert isinstance(node.args[-1], int) 126 return node.args[-1] 127 return 0 # defaults to dim=0 128 raise AssertionError( 129 f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}" 130 ) 131 132 133# noqa: W605 134# ############The pattern to be optimized is######### 135# unbind (dim=0) 136# / ... \ 137# getitem getitem -> user=1 138# | | 139# split split -> dim=1, user=1, split_section_size=1 140# | | 141# getitem getitem -> user=1 142# \ / 143# cat (dim=1) -> user=1 144# | 145 146# ################After transformation############# 147# unbind (dim=0) 148# / ... \ 149# getitem getitem -> user=1 150# \ / 151# cat (dim=1) -> user=1 152# | 153 154 155def normalize_split_base( 156 match: Match, 157 _get_split_args: Callable[ 158 [torch.fx.Node], Tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] 159 ], 160): 161 """ 162 Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in 163 subsequent optimizations 164 """ 165 split_node = match.nodes[0] 166 graph = match.graph 167 split_input, split_size, split_dim = _get_split_args(split_node) 168 if split_input is None or split_dim is None or split_size is None: 169 log.debug("couldn't find split args") 170 return 171 if not is_node_meta_valid(split_node): 172 log.debug("example value absent for node: %s", split_node) 173 return 174 assert isinstance(split_node.meta["example_value"], (list, tuple)) 175 split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] 176 177 if any(isinstance(section, torch.SymInt) for section in split_sections): 178 # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 179 return 180 if split_dim < 0: # Normalize split dim 181 split_dim += split_input.meta["example_value"].dim() 182 183 new_args = (split_input, split_sections) 184 new_kwargs = {"dim": split_dim} 185 if ( 186 split_node.args == new_args 187 and split_node.kwargs == new_kwargs 188 and split_node.op == "call_function" 189 ): 190 return 191 192 with graph.inserting_after(split_node): 193 new_split_node = graph.call_function( 194 torch.split, 195 args=new_args, 196 kwargs=new_kwargs, # type: ignore[arg-type] 197 ) 198 split_node.replace_all_uses_with(new_split_node) 199 new_split_node.meta.update(split_node.meta) 200 graph.erase_node(split_node) 201 counters["inductor"]["normalization_pass"] += 1 202 203 204@register_graph_pattern( 205 CallFunctionVarArgs(torch.split, users=MULTIPLE), 206 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 207) 208@register_graph_pattern( 209 CallMethodVarArgs("split", users=MULTIPLE), 210 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 211) 212def normalize_split_default(match: Match, *args, **kwargs): 213 return normalize_split_base(match, _get_split_args_default) 214 215 216@register_graph_pattern( 217 CallFunctionVarArgs(torch.split, users=MULTIPLE), 218 pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), 219) 220@register_graph_pattern( 221 CallMethodVarArgs("split", users=MULTIPLE), 222 pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), 223) 224def remove_split_with_size_one(match: Match, *args, **kwargs): 225 graph = match.graph 226 split_node = match.nodes[0] 227 split_input, split_size, split_dim = _get_split_args_default(split_node) 228 if split_input is None or split_dim is None or split_size is None: 229 log.debug("couldn't find split args") 230 return 231 if not is_node_meta_valid(split_node): 232 log.debug("example value absent for node: %s", split_node) 233 return 234 assert isinstance(split_node.meta["example_value"], (list, tuple)) 235 split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] 236 237 if any(isinstance(section, torch.SymInt) for section in split_sections): 238 # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 239 return 240 # remove the dummy split whose split sections size is one 241 if len(split_sections) == 1: 242 # find the grand children of the split_node 243 next_users = find_next_users(split_node) 244 user = next(iter(split_node.users.keys())) 245 # replace the users of grand child node with the input node 246 for next_user in next_users: 247 next_user.replace_input_with(user, split_input) 248 # erase the split node and its child 249 graph.erase_node(user) 250 graph.erase_node(split_node) 251 counters["inductor"]["remove_split_with_size_one_pass"] += 1 252 253 254@register_graph_pattern( 255 CallFunctionVarArgs(torch.unbind, users=MULTIPLE), 256 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 257) 258@register_graph_pattern( 259 CallMethodVarArgs("unbind", users=MULTIPLE), 260 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 261) 262def normalize_unbind_default(match: Match, *args, **kwargs): 263 node = match.nodes[0] 264 graph = match.graph 265 input = get_arg_value(node, 0, "input") 266 dim = get_arg_value(node, 1, "dim") 267 if dim is None: 268 axis = node.kwargs.get("axis") 269 if axis is not None: 270 dim = axis 271 else: 272 dim = 0 273 if input is None: 274 log.debug("couldn't find unbind args") 275 return 276 if not is_node_meta_valid(input): 277 log.debug("example value absent for node: %s", input) 278 return 279 ndim = input.meta["example_value"].ndim 280 if dim < 0: # Normalize unbind dim 281 dim += ndim 282 with graph.inserting_after(node): 283 new_node = graph.call_function( 284 torch.unbind, 285 args=(input,), 286 kwargs={"dim": dim}, 287 ) 288 node.replace_all_uses_with(new_node) 289 new_node.meta.update(node.meta) 290 graph.erase_node(node) 291 counters["inductor"]["normalization_pass"] += 1 292 293 294@register_graph_pattern( 295 CallFunctionVarArgs(torch.cat, users=MULTIPLE), 296 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 297) 298def normalize_cat_default(match: Match, *args, **kwargs): 299 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 300 301 cat_node = match.nodes[0] 302 graph = match.graph 303 tensors = get_arg_value(cat_node, 0, "tensors") 304 cat_dim = get_arg_value(cat_node, 1, "dim") 305 if cat_dim is None: 306 cat_axis = cat_node.kwargs.get("axis") 307 if cat_axis is not None: 308 cat_dim = cat_axis 309 else: 310 cat_dim = 0 311 if tensors is None or cat_dim is None: 312 log.debug("couldn't find cat args") 313 return 314 assert isinstance(tensors, (list, tuple)) 315 for tensor in itertools.chain([cat_node], tensors): 316 if not is_node_meta_valid(tensor): 317 log.debug("example value absent for node: %s", tensor) 318 return 319 320 ndim = cat_node.meta["example_value"].dim() 321 322 def is_empty_tensor(x): 323 # special case where torch.cat supports cat'ing with an empty tensor 324 x_shape = x.meta["example_value"].shape 325 return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) 326 327 assert all( 328 ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors 329 ) 330 331 if cat_dim < 0: # Normalize cat dim 332 cat_dim += ndim 333 334 new_args = (tensors,) 335 new_kwargs = {"dim": cat_dim} 336 if ( 337 cat_node.args == new_args 338 and cat_node.kwargs == new_kwargs 339 and cat_node.op == "call_function" 340 ): 341 return 342 343 with graph.inserting_after(cat_node): 344 new_cat_node = graph.call_function( 345 torch.cat, 346 args=new_args, 347 kwargs=new_kwargs, 348 ) 349 cat_node.replace_all_uses_with(new_cat_node) 350 new_cat_node.meta.update(cat_node.meta) 351 graph.erase_node(cat_node) 352 counters["inductor"]["normalization_pass"] += 1 353 354 355@register_graph_pattern( 356 CallFunctionVarArgs(torch.stack, users=MULTIPLE), 357 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 358) 359def normalize_stack_default(match: Match, *args, **kwargs): 360 node = match.nodes[0] 361 graph = match.graph 362 tensors = get_arg_value(node, 0, "tensors") 363 dim = get_arg_value(node, 1, "dim") or 0 364 if tensors is None or dim is None: 365 log.debug("couldn't find stack args") 366 return 367 assert isinstance(tensors, (list, tuple)) 368 369 # A bug in pytorch, some nodes miss the example_value metadata 370 for tensor in itertools.chain([node], tensors): 371 if not is_node_meta_valid(tensor): 372 log.debug("example value absent for node: %s", tensor) 373 return 374 375 ndim = node.meta["example_value"].dim() 376 if dim < 0: # Normalize dim 377 dim += ndim 378 379 with graph.inserting_after(node): 380 new_node = graph.call_function( 381 node.target, # type: ignore[arg-type] 382 args=(tensors,), 383 kwargs={"dim": dim}, 384 ) 385 node.replace_all_uses_with(new_node) 386 new_node.meta.update(node.meta) 387 graph.erase_node(node) 388 counters["inductor"]["normalization_pass"] += 1 389 390 391def find_next_users(split_node: torch.fx.Node) -> List[torch.fx.Node]: 392 next_users = [] 393 for getitem_node in split_node.users.keys(): 394 for getitem_user in getitem_node.users.keys(): 395 if getitem_user not in next_users: 396 next_users.append(getitem_user) 397 return next_users 398 399 400@register_graph_pattern( 401 CallMethodVarArgs("squeeze", users=MULTIPLE), 402 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 403) 404def normalize_squeeze_default(match: Match, *args, **kwargs): 405 squeeze_node = match.nodes[0] 406 squeeze_input = get_arg_value(squeeze_node, 0) 407 408 if "dim" in squeeze_node.kwargs: 409 assert len(squeeze_node.args) == 1 410 dim = squeeze_node.kwargs["dim"] 411 elif len(squeeze_node.args) == 1: 412 # squeeze(Tensor) 413 dim = None 414 elif len(squeeze_node.args) == 2: 415 # squeeze(Tensor self, int dim) 416 # squeeze(Tensor self, int[] dim) 417 dim = squeeze_node.args[1] 418 else: 419 # squeeze(Tensor self, int[] dim) (called with varargs) 420 dim = squeeze_node.args[1:] 421 422 if isinstance(dim, Sequence) and len(dim) == 1: 423 dim = dim[0] 424 425 with match.graph.inserting_after(squeeze_node): 426 if dim is None: 427 new_squeeze_node = match.graph.call_function( 428 torch.squeeze, args=(squeeze_input,) 429 ) 430 else: 431 new_squeeze_node = match.graph.call_function( 432 torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} 433 ) 434 squeeze_node.replace_all_uses_with(new_squeeze_node) 435 new_squeeze_node.meta.update(squeeze_node.meta) 436 match.graph.erase_node(squeeze_node) 437 438 439@register_graph_pattern( 440 CallMethodVarArgs("reshape", users=MULTIPLE), 441 pass_dict=construct_pattern_matcher_pass("normalization_pass"), 442) 443def normalize_reshape_default(match: Match, *args, **kwargs): 444 reshape_node = match.nodes[0] 445 if not is_node_meta_valid(reshape_node): 446 log.debug("example value absent for node: %s", reshape_node) 447 return 448 reshape_input = get_arg_value(reshape_node, 0) 449 450 with match.graph.inserting_after(reshape_node): 451 new_reshape_node = match.graph.call_function( 452 torch.reshape, 453 args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)), 454 ) 455 reshape_node.replace_all_uses_with(new_reshape_node) 456 new_reshape_node.meta.update(reshape_node.meta) 457 match.graph.erase_node(reshape_node) 458 459 460class TorchSplit(CallFunction): 461 """ 462 Matches a call to torch.split if it is in a normalized form. Ensures that all users of 463 splits are unique getitems. 464 """ 465 466 def __init__(self, arg, sizes, func=torch.split) -> None: 467 # using KeywordArg("dim") for `dim` checks they all match 468 super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) 469 470 def _match(self, node: torch.fx.Node, ctx: MatchContext): 471 m = super()._match(node, ctx) 472 if not m: 473 return m 474 split_sections = node.args[1] 475 if not isinstance(split_sections, (list, tuple)): 476 return FailedMatch("split not normalized") 477 # check users are all unique getitems 478 seen_idxs = set() 479 for user in node.users: 480 if not CallFunction(operator.getitem, Arg(), Arg()).match(user): 481 # This should ideally never happen. Split user should always be a getitem 482 return FailedMatch(f"user of split not a getitem: {user}") 483 if not isinstance(user.args[1], int): 484 return FailedMatch("only integer getitems are handled") 485 if user.args[1] in seen_idxs: 486 return FailedMatch(f"duplicate getitem {user.args[1]}") 487 if user.args[-1] < 0: # type: ignore[operator] 488 # This shouldn't ideally happen as dynamo normalizes indexes to positive 489 return FailedMatch("negative index") 490 seen_idxs.add(user.args[1]) 491 return m 492 493 494@register_graph_pattern( 495 TorchSplit( 496 CallFunction( 497 operator.getitem, 498 TorchSplit( 499 KeywordArg("first_split_input"), 500 KeywordArg("first_split_sections"), 501 ), 502 Ignored(), 503 ), 504 KeywordArg("next_split_sections"), 505 ), 506 pass_dict=construct_pattern_matcher_pass("merge_splits_pass"), 507) 508def merge_splits( 509 match: Match, 510 first_split_input: torch.fx.Node, 511 first_split_sections: List[int], 512 next_split_sections: List[int], 513 # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim 514 dim: int, 515): 516 node = match.output_node() 517 # it is possible that the split has no users, 518 # we check the corner case and skip the pattern 519 if len(node.users.keys()) == 0: 520 return 521 graph = match.graph 522 first_split = node.args[0].args[0] # type: ignore[union-attr] 523 next_split_index = node.args[0].args[1] # type: ignore[union-attr] 524 525 new_split_sections = list(first_split_sections) 526 new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] 527 528 first_split_dim = _get_dim(first_split) 529 530 to_remove = [] 531 532 with graph.inserting_before(first_split): # type: ignore[arg-type] 533 # Add the new split node 534 new_split = graph.call_function( 535 torch.split, 536 args=(first_split_input, new_split_sections), 537 kwargs={"dim": first_split_dim}, 538 ) 539 if is_node_meta_valid(first_split_input): 540 new_split.meta["example_value"] = torch.split( 541 first_split_input.meta["example_value"], 542 new_split_sections, 543 dim=first_split_dim, 544 ) 545 first_split_num_to_user = { 546 user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr] 547 } 548 549 new_split_num = 0 550 for split_num in range(len(first_split_sections)): 551 if split_num not in first_split_num_to_user: 552 new_split_num += 1 553 continue 554 old_getitem = first_split_num_to_user[split_num] 555 if split_num != next_split_index: 556 old_getitem.update_arg(0, new_split) 557 old_getitem.update_arg(1, new_split_num) 558 new_split_num += 1 559 else: 560 next_split_num_to_user = { 561 user.args[1]: user for user in node.users.keys() 562 } 563 # It is not necessary all getitems from the split node are used. 564 # We use the num of users to check the getitems to be merged. 565 for next_split_num in range(len(node.users.keys())): 566 with graph.inserting_after(new_split): 567 new_getitem = graph.call_function( 568 operator.getitem, args=(new_split, new_split_num) 569 ) 570 new_split_num += 1 571 next_getitem = next_split_num_to_user[next_split_num] 572 new_getitem.meta.update(next_getitem.meta) 573 next_getitem.replace_all_uses_with(new_getitem) 574 to_remove.append(next_getitem) 575 to_remove.append(node) 576 to_remove.append(old_getitem) 577 578 to_remove.append(first_split) # type: ignore[arg-type] 579 for node in to_remove: 580 graph.erase_node(node) 581 582 counters["inductor"]["merge_splits_pass"] += 1 583 584 585class SplitCatSimplifier: 586 """ 587 Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" 588 pattern. However, there are various cases where they can't and we need to simplify split/ add transforms before cat. 589 Some such cases are: 590 1. Final node has additional args (not coming from the initial split) 591 2. Shuffling of args between split/cat 592 3. Some final nodes are non-(cat/stack) 593 4. Split-dim != cat-dim (but equal split) 594 595 Note that any combination of the above cases can happen. 596 597 To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. 598 Then, we simplify the split accordingly. In the best case, split can be entirely removed. 599 600 To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). 601 602 Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. 603 604 """ 605 606 def simplify( 607 self, 608 graph: torch.fx.Graph, 609 split_node: torch.fx.Node, 610 split_sections: List[int], 611 ): 612 # Find the next users (i.e. users after the getitem) 613 next_users = find_next_users(split_node) 614 # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by 615 # a tuple indicating the split ranges. See `get_user_input_list` for more details 616 user_inputs_list = self.get_user_input_list(split_node, next_users) 617 # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and 618 # we can simply replace the split node. Otherwise, we simplify it. 619 simplified_split_ranges = self.get_simplified_split_ranges( 620 split_sections, next_users, user_inputs_list 621 ) 622 if not simplified_split_ranges: # Simplification not possible 623 return 624 transform_params_list = self.get_transform_params( 625 split_node, next_users, user_inputs_list 626 ) 627 if not transform_params_list: 628 return 629 630 # Start actual replacement 631 user_inputs_list_new = self.replace_split( 632 graph, split_node, split_sections, user_inputs_list, simplified_split_ranges 633 ) 634 self.replace_cat( 635 graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] 636 ) 637 self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] 638 counters["inductor"]["unbind_stack_pass"] += 1 639 640 def get_user_input_list( 641 self, split_node: torch.fx.Node, next_users: List[torch.fx.Node] 642 ) -> List[List[Union[torch.fx.Node, _Range]]]: 643 """ 644 Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner 645 list represents the inputs to that particular node. This list can either contain 646 - a tuple representing the ranges of get_items that should go into the cat (closed interval) 647 - torch.fx.Node representing "other" inputs (which are not coming from our split) 648 """ 649 user_inputs_list: List[List[Union[torch.fx.Node, _Range]]] = [] 650 for user in next_users: 651 if user.target in {torch.cat, torch.stack}: 652 user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) 653 else: 654 user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] 655 return user_inputs_list 656 657 def get_merged_user_inputs( 658 self, split_node: torch.fx.Node, cat_node: torch.fx.Node 659 ) -> List[Union[torch.fx.Node, _Range]]: 660 user_inputs = get_arg_value(cat_node, 0, "tensors") 661 simplified_user_inputs = [] 662 split_users = set(split_node.users.keys()) 663 for user_input in user_inputs: 664 if user_input not in split_users: 665 simplified_user_inputs.append(user_input) 666 else: 667 # Add which "getitem" cat depends on 668 simplified_user_inputs.append(user_input.args[1]) 669 return self.merge_consecutive_inputs(simplified_user_inputs) 670 671 def get_non_cat_node_input( 672 self, split_node: torch.fx.Node, node: torch.fx.Node 673 ) -> List[_Range]: 674 """ 675 Get input for a non cat node in the same format as `get_merged_user_inputs` 676 """ 677 node_input = [] 678 split_users = set(split_node.users.keys()) 679 for node_arg in node.all_input_nodes: 680 if node_arg in split_users: 681 getitem_num = get_arg_value(node_arg, 1) 682 node_input.append((getitem_num, getitem_num)) 683 return node_input 684 685 def merge_consecutive_inputs( 686 self, inputs: List[Union[torch.fx.Node, int]] 687 ) -> List[Union[torch.fx.Node, _Range]]: 688 """ 689 Merge consecutive inputs going into a user node. 690 691 For e.g. 692 [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] 693 """ 694 merged_ranges = [] 695 cur_range = None 696 for input_ in inputs: 697 if isinstance(input_, int): 698 if not cur_range: 699 cur_range = [input_, input_] 700 elif input_ == cur_range[1] + 1: 701 cur_range[1] += 1 702 else: 703 merged_ranges.append(tuple(cur_range)) 704 cur_range = [input_, input_] 705 else: 706 if cur_range: 707 merged_ranges.append(tuple(cur_range)) 708 cur_range = None 709 merged_ranges.append(input_) # type: ignore[arg-type] 710 if cur_range: 711 merged_ranges.append(tuple(cur_range)) 712 return merged_ranges # type: ignore[return-value] 713 714 def get_simplified_split_ranges( 715 self, 716 split_sections, 717 next_users, 718 user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], 719 ) -> Optional[List[_Range]]: 720 ranges = set() 721 for user_node, user_inputs in zip(next_users, user_inputs_list): 722 ranges |= { 723 user_input 724 for user_input in user_inputs 725 if isinstance(user_input, tuple) 726 } 727 cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() 728 split_ranges = sorted( 729 [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] 730 ) 731 732 if not self.has_non_overlapping_ranges( 733 split_ranges, 734 ): # This need not be a strict condition 735 # However, we keep it now for simplicity. 736 return None 737 split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) 738 if len(split_sections) == len(split_ranges): # Simplification not possible 739 return None 740 counters["inductor"]["scmerge_split_sections_removed"] = len( 741 split_sections 742 ) - len(split_ranges) 743 return split_ranges 744 745 def has_non_overlapping_ranges(self, ranges: List[_Range]) -> bool: 746 for range_, next_range in zip(ranges, ranges[1:]): 747 if range_[1] > next_range[0]: 748 return False 749 return True 750 751 def fill_gaps(self, ranges: List[_Range], min_: int, max_: int) -> List[_Range]: 752 cur = min_ 753 filled_ranges = [] 754 for a, b in ranges: 755 if cur < a: 756 filled_ranges.append((cur, a)) 757 filled_ranges.append((a, b)) 758 cur = b 759 if filled_ranges[-1][1] < max_: 760 filled_ranges.append((filled_ranges[-1][1], max_)) 761 return filled_ranges 762 763 def get_transform_params( 764 self, 765 split_node: torch.fx.Node, 766 next_users: List[torch.fx.Node], 767 user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], 768 ) -> Optional[List[List[_TransformParam]]]: 769 """ 770 Figure out what transforms are needed for each input to each cat node. 771 772 We replace a split node with an unflatten followed by a movedim 773 """ 774 split_dim = _get_dim(split_node) 775 split_sections = split_node.args[1] 776 transform_params_list: List[List[_TransformParam]] = [] 777 778 for user_node, user_inputs in zip(next_users, user_inputs_list): 779 if user_node.target not in {torch.cat, torch.stack}: 780 transform_params_list.append([]) 781 continue 782 783 cat_dim = get_arg_value(user_node, 1, "dim") 784 transform_params: List[_TransformParam] = [] 785 for user_input in user_inputs: 786 if split_dim == cat_dim and user_node.target == torch.cat: 787 # No transform needed 788 transform_params.append((None, None, None, None)) 789 elif isinstance(user_input, tuple): # Split being simplified 790 # Verify equal split 791 subset_split_sections = split_sections[ # type: ignore[index] 792 user_input[0] : user_input[1] + 1 793 ] 794 # All sections should be equal 795 if len(set(subset_split_sections)) != 1: 796 return None 797 798 num_splits = len(subset_split_sections) 799 unflatten_params = (split_dim, (num_splits, -1)) 800 movedim_params = ( 801 (split_dim, cat_dim) if split_dim != cat_dim else None 802 ) 803 transform_params.append( 804 (unflatten_params, movedim_params, None, None) 805 ) 806 elif ( 807 user_node.target == torch.stack or split_dim != cat_dim 808 ): # We need to unsqueeze inputs not coming through split 809 transform_params.append((None, None, (cat_dim,), None)) 810 else: # Non-split inputs 811 transform_params.append((None, None, None, None)) 812 transform_params_list.append(transform_params) 813 return transform_params_list 814 815 def replace_split( 816 self, 817 graph: torch.fx.Graph, 818 split_node: torch.fx.Node, 819 split_sections: List[int], 820 user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], 821 split_ranges: List[_Range], 822 ) -> List[List[torch.fx.Node]]: 823 """ 824 Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it 825 into a split with lesser sections if len(split_ranges) > 1. 826 827 Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. 828 """ 829 split_input = split_node.args[0] 830 split_dim = _get_dim(split_node) 831 if len(split_ranges) == 1: # We can completely eliminate the split node 832 split_items = [split_input] 833 else: 834 with graph.inserting_after(split_node): 835 new_split = graph.call_function( 836 torch.split, 837 args=( 838 split_input, 839 [r[1] - r[0] for r in split_ranges], 840 ), 841 kwargs={"dim": split_dim}, 842 ) 843 if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] 844 new_split.meta["example_value"] = torch.split( 845 split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr] 846 ) 847 counters["inductor"]["scmerge_split_added"] += 1 848 split_items = [] 849 with graph.inserting_after(new_split): 850 for i in range(len(split_ranges)): 851 getitem = graph.call_function(operator.getitem, args=(new_split, i)) 852 if is_node_meta_valid(new_split): 853 getitem.meta["example_value"] = new_split.meta["example_value"][ 854 i 855 ] 856 split_items.append(getitem) 857 # Now assign the right getitem to the right input 858 cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() 859 new_user_inputs_list = [] 860 for user_inputs in user_inputs_list: 861 new_user_inputs = [] 862 for user_input in user_inputs: 863 if isinstance(user_input, tuple): 864 # Find the correct new getitem (present in split_items) 865 new_user_inputs.append( 866 split_items[ 867 split_ranges.index( 868 ( 869 cumulative_sizes[user_input[0]], 870 cumulative_sizes[user_input[1] + 1], 871 ) 872 ) 873 ] 874 ) 875 else: 876 new_user_inputs.append(user_input) 877 new_user_inputs_list.append(new_user_inputs) 878 return new_user_inputs_list # type: ignore[return-value] 879 880 def replace_cat( 881 self, 882 graph: torch.fx.GraphModule, 883 split_node: torch.fx.Node, 884 next_users: List[torch.fx.Node], 885 user_inputs_list_new, 886 transform_params_list: List[List[_TransformParam]], 887 ): 888 split_dim = _get_dim(split_node) 889 split_users = split_node.users.keys() 890 new_cats = [] 891 for user_node, user_inputs_new, transform_params in zip( 892 next_users, user_inputs_list_new, transform_params_list 893 ): 894 if user_node.target not in {torch.cat, torch.stack}: 895 # Change the args and kwargs of non-cat/stack nodes. Replace old getitems (belonging to 896 # the original split node) with the newer getitems 897 next_cat_input = 0 898 for input_node in user_node.all_input_nodes: 899 if input_node in split_users: 900 user_node.replace_input_with( 901 input_node, user_inputs_new[next_cat_input] 902 ) 903 next_cat_input += 1 904 continue 905 906 # Handle cat/stack user nodes 907 cat_dim = get_arg_value(user_node, 1, "dim") 908 user_inputs_new_transformed, user_inputs_new_transformed_meta = [], [] 909 # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them 910 to_stack, to_stack_meta = [], [] 911 stack_dim = None 912 with graph.inserting_before(user_node): 913 for user_input_new, transform_param in zip( 914 user_inputs_new, transform_params 915 ): 916 if not is_node_meta_valid(user_input_new): 917 log.debug("example value absent for node: %s", user_input_new) 918 return 919 # Apply transforms 920 ( 921 unflatten_params, 922 movedim_params, 923 unsqueeze_params, 924 flatten_params, 925 ) = transform_param 926 if unsqueeze_params and ( 927 stack_dim is None or stack_dim == unsqueeze_params[0] 928 ): 929 to_stack.append(user_input_new) 930 to_stack_meta.append(user_input_new.meta["example_value"]) 931 stack_dim = unsqueeze_params[0] 932 continue 933 elif to_stack: 934 stacked_input = graph.call_function( 935 torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} 936 ) 937 stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] 938 to_stack, to_stack_meta = [], [] 939 stack_dim = None 940 user_inputs_new_transformed.append(stacked_input) 941 user_inputs_new_transformed_meta.append( 942 stacked_input.meta["example_value"] 943 ) 944 if unsqueeze_params: 945 to_stack.append(user_input_new) 946 stack_dim = unsqueeze_params[0] 947 to_stack_meta.append(user_input_new.meta["example_value"]) 948 continue 949 950 if unflatten_params: 951 user_input_new_meta = user_input_new.meta["example_value"] 952 user_input_new = graph.call_function( 953 torch.unflatten, args=(user_input_new, *unflatten_params) 954 ) 955 user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] 956 if movedim_params: 957 user_input_new_meta = user_input_new.meta["example_value"] 958 user_input_new = graph.call_function( 959 torch.movedim, args=(user_input_new, *movedim_params) 960 ) 961 user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr] 962 if flatten_params: 963 user_input_new_meta = user_input_new.meta["example_value"] 964 user_input_new = graph.call_function( 965 torch.flatten, args=(user_input_new, *flatten_params) 966 ) 967 user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] 968 user_inputs_new_transformed.append(user_input_new) 969 user_inputs_new_transformed_meta.append( 970 user_input_new.meta["example_value"] 971 ) 972 if to_stack: 973 stacked_input = graph.call_function( 974 torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} 975 ) 976 stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] 977 user_inputs_new_transformed.append(stacked_input) 978 user_inputs_new_transformed_meta.append( 979 stacked_input.meta["example_value"] 980 ) 981 982 with graph.inserting_after(user_node): 983 if len(user_inputs_new_transformed) > 1: 984 new_cat_node = graph.call_function( 985 torch.cat, 986 args=(user_inputs_new_transformed,), 987 kwargs={"dim": cat_dim}, 988 ) 989 new_cat_node.meta["example_value"] = torch.cat( 990 user_inputs_new_transformed_meta, dim=cat_dim 991 ) 992 counters["inductor"]["scmerge_cat_added"] += 1 993 else: 994 new_cat_node = user_inputs_new_transformed[-1] 995 new_cat_node.meta[ 996 "example_value" 997 ] = user_inputs_new_transformed_meta[-1] 998 999 if ( 1000 user_node.target == torch.cat 1001 and split_dim != cat_dim 1002 and split_node.target == torch.split 1003 ): 1004 with graph.inserting_after(new_cat_node): 1005 new_cat_node_meta = new_cat_node.meta["example_value"] 1006 new_cat_node = graph.call_function( 1007 torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) 1008 ) 1009 new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr] 1010 user_node.replace_all_uses_with(new_cat_node) 1011 new_cats.append(new_cat_node) 1012 1013 def erase_old_nodes( 1014 self, 1015 graph: torch.fx.GraphModule, 1016 split_node: torch.fx.Node, 1017 next_users: List[torch.fx.Node], 1018 ): 1019 to_remove = [split_node] 1020 counters["inductor"]["scmerge_split_removed"] += 1 1021 to_remove.extend(split_node.users.keys()) 1022 for next_user in next_users: 1023 if next_user.target not in {torch.cat, torch.stack}: 1024 continue 1025 counters["inductor"]["scmerge_cat_removed"] += 1 1026 to_remove.append(next_user) 1027 for node in reversed(to_remove): 1028 if len(node.users.keys()) == 0: 1029 graph.erase_node(node) 1030 1031 1032class UnbindCatRemover(SplitCatSimplifier): 1033 """ 1034 Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. 1035 1036 Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, 1037 other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, 1038 hence we extend that class. 1039 """ 1040 1041 def remove_unbind( 1042 self, 1043 graph: torch.fx.Graph, 1044 unbind_node: torch.fx.Node, 1045 ): 1046 if not is_node_meta_valid(unbind_node): 1047 return 1048 # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node 1049 # before we do the unbind remove, otherwise it will hit the error when we unbind part of them 1050 getitem_indices = [] 1051 for getitem_node in unbind_node.users.keys(): 1052 getitem_indices.append(getitem_node.args[1]) 1053 if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] 1054 getitem_indices 1055 ) != len( 1056 unbind_node.meta["example_value"] 1057 ): 1058 return 1059 num_unbind = len(getitem_indices) 1060 split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] 1061 1062 super().simplify(graph, unbind_node, split_sections) 1063 1064 def get_simplified_split_ranges( 1065 self, 1066 split_sections: List[int], 1067 next_users: List[torch.fx.Node], 1068 user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], 1069 ) -> Optional[List[_Range]]: 1070 simplified_split_ranges = super().get_simplified_split_ranges( 1071 split_sections, next_users, user_inputs_list 1072 ) 1073 if not simplified_split_ranges or len(simplified_split_ranges) != 1: 1074 return None 1075 return simplified_split_ranges 1076 1077 def get_transform_params( 1078 self, 1079 split_node: torch.fx.Node, 1080 next_users: List[torch.fx.Node], 1081 user_inputs_list: List[List[Union[torch.fx.Node, _Range]]], 1082 ) -> Optional[List[List[_TransformParam]]]: 1083 """ 1084 Figure out what transforms are needed for each input to each cat node. 1085 1086 Here is the rough transforms we apply: 1087 1088 x -> unbind -> stack => x -> movedim 1089 1090 x -> unbind -> cat => x -> movedim -> flatten 1091 1092 When cat/stack nodes have additional args: 1093 1094 addn ---| addn -> unsqueeze ---| 1095 x -> unbind -> stack => x -> movedim -> cat 1096 1097 addn ---| addn ---| 1098 x -> unbind -> cat => x -> movedim -> flatten -> cat 1099 1100 (Note application of these depends on the dims as well) 1101 1102 1103 """ 1104 split_dim = _get_dim(split_node) 1105 transform_params_list: List[List[_TransformParam]] = [] 1106 for user_node, user_inputs in zip(next_users, user_inputs_list): 1107 cat_dim = get_arg_value(user_node, 1, "dim") or 0 1108 transform_params: List[_TransformParam] = [] 1109 for user_input in user_inputs: 1110 if isinstance(user_input, tuple): 1111 # User input is coming from unbind 1112 movedim_params = ( 1113 (split_dim, cat_dim) if split_dim != cat_dim else None 1114 ) 1115 flatten_params = None 1116 if user_node.target == torch.cat: 1117 flatten_params = (cat_dim, cat_dim + 1) 1118 transform_params.append( 1119 (None, movedim_params, None, flatten_params) 1120 ) 1121 elif ( 1122 user_node.target == torch.stack 1123 ): # We need to unsqueeze inputs not coming through unbind into cat 1124 transform_params.append((None, None, (cat_dim,), None)) 1125 else: # Non-unbind inputs 1126 transform_params.append((None, None, None, None)) 1127 transform_params_list.append(transform_params) 1128 return transform_params_list 1129 1130 1131class GetItem(CallFunction): 1132 def __init__(self, arg, index, _users=1) -> None: 1133 super().__init__(operator.getitem, arg, index, _users=_users) 1134 1135 def find_anchor_nodes(self, ctx: MatchContext, searched: Set[torch.fx.Node]): 1136 # We generally match GetItem with arg being an Arg(). So, we never return the anchor 1137 # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes 1138 # to not use ctx.pattern_to_node 1139 for pattern in self.flat_args_kwargs[0]: 1140 if isinstance(pattern, PatternExpr): 1141 for other_node in pattern.find_anchor_nodes(ctx, searched): 1142 if not isinstance(other_node, torch.fx.Node): 1143 continue 1144 for node in other_node.users: 1145 if node not in searched: 1146 if self._match_fns(node): 1147 yield node 1148 searched.add(node) 1149 1150 1151@register_graph_pattern( 1152 RepeatedExpr( 1153 CallFunction( 1154 torch.squeeze, 1155 GetItem( 1156 TorchSplit( 1157 KeywordArg("split_input"), 1158 KeywordArg("split_sizes"), 1159 ), 1160 Ignored(), 1161 ), 1162 KeywordArg("dim"), 1163 _users=MULTIPLE, 1164 ), 1165 ), 1166 pass_dict=construct_pattern_matcher_pass("split_cat_pass"), 1167) 1168@register_graph_pattern( 1169 RepeatedExpr( 1170 CallFunction( 1171 torch.squeeze, 1172 GetItem( 1173 TorchSplit( 1174 KeywordArg("split_input"), 1175 KeywordArg("split_sizes"), 1176 ), 1177 Ignored(), 1178 ), 1179 dim=KeywordArg("dim"), 1180 _users=MULTIPLE, 1181 ) 1182 ), 1183 pass_dict=construct_pattern_matcher_pass("split_cat_pass"), 1184) 1185def merge_split_squeeze( 1186 match: Match, split_input: torch.fx.Node, split_sizes: List[int], dim: int 1187): 1188 graph = match.graph 1189 split = next(node for node in match.nodes if node.target == torch.split) 1190 if not all(s == 1 for s in split_sizes): 1191 return 1192 if isinstance(dim, Sequence): 1193 return 1194 next_users = find_next_users(split) 1195 if not all(node.target == torch.squeeze for node in next_users): 1196 return 1197 with graph.inserting_before(match.output_node()): 1198 unbind = graph.call_function( 1199 torch.unbind, args=(split_input,), kwargs={"dim": dim} 1200 ) 1201 if is_node_meta_valid(split_input): 1202 unbind.meta["example_value"] = torch.unbind( 1203 split_input.meta["example_value"], dim=dim 1204 ) 1205 for item_index, getitem_node in sorted( 1206 [ 1207 (getitem_node.args[1], getitem_node) 1208 for getitem_node in split.users.keys() 1209 ] 1210 ): 1211 squeeze = next(iter(getitem_node.users.keys())) 1212 new_get_item = graph.call_function( 1213 operator.getitem, args=(unbind, item_index) 1214 ) 1215 squeeze.replace_all_uses_with(new_get_item) 1216 new_get_item.meta.update(squeeze.meta) 1217 graph.erase_node(squeeze) 1218 graph.erase_node(getitem_node) 1219 graph.erase_node(split) 1220 counters["inductor"]["split_cat_pass"] += 1 1221 1222 1223getitem_unbind = ListOf( 1224 GetItem( 1225 CallFunction( 1226 torch.unbind, 1227 KeywordArg("unbind_input"), 1228 dim=KeywordArg("dim"), 1229 _users=MULTIPLE, 1230 ), 1231 Ignored(), 1232 _users=MULTIPLE, 1233 ), 1234 partial=True, 1235) 1236 1237 1238@register_graph_pattern( 1239 CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), 1240 pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), 1241) 1242@register_graph_pattern( 1243 CallFunction( 1244 [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE 1245 ), 1246 pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), 1247) 1248@register_graph_pattern( 1249 CallFunction( 1250 [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE 1251 ), 1252 pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), 1253) 1254def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): 1255 unbind_node = next(node for node in match.nodes if node.target == torch.unbind) 1256 UnbindCatRemover().remove_unbind(match.graph, unbind_node) 1257 1258 1259getitem_split = ListOf( 1260 CallFunction( 1261 operator.getitem, 1262 TorchSplit( 1263 Ignored(), 1264 KeywordArg("split_sections"), 1265 ), 1266 Ignored(), 1267 _users=MULTIPLE, 1268 ), 1269 partial=True, 1270) 1271 1272 1273reshape_getitem_split = ListOf( 1274 CallFunction( 1275 torch.reshape, 1276 CallFunction( 1277 operator.getitem, 1278 TorchSplit( 1279 Ignored(), 1280 KeywordArg("split_sections"), 1281 ), 1282 Ignored(), 1283 _users=MULTIPLE, 1284 ), 1285 Arg(), 1286 _users=MULTIPLE, 1287 ), 1288 partial=True, 1289) 1290 1291 1292@register_graph_pattern( 1293 CallFunction( 1294 [torch.stack, torch.cat], 1295 tensors=getitem_split, 1296 dim=Ignored(), 1297 _users=MULTIPLE, 1298 ), 1299 pass_dict=construct_pattern_matcher_pass("split_cat_pass"), 1300) 1301@register_graph_pattern( 1302 CallFunction( 1303 [torch.stack, torch.cat], 1304 getitem_split, 1305 dim=Ignored(), 1306 _users=MULTIPLE, 1307 ), 1308 pass_dict=construct_pattern_matcher_pass("split_cat_pass"), 1309) 1310@register_graph_pattern( 1311 CallFunction( 1312 [torch.stack, torch.cat], 1313 getitem_split, 1314 Ignored(), 1315 _users=MULTIPLE, 1316 ), 1317 pass_dict=construct_pattern_matcher_pass("split_cat_pass"), 1318) 1319def simplify_split_cat(match: Match, split_sections: List[int], dim: int): 1320 if not isinstance(split_sections, (list, tuple)): # Unnormalized split 1321 return 1322 split_node = next(node for node in match.nodes if node.target == torch.split) 1323 SplitCatSimplifier().simplify(match.graph, split_node, split_sections) 1324 1325 1326# noqa: W605 1327# ############pattern to be optimized is######### 1328 1329# split_node(dim=1) 1330# / \ ... / \ 1331# getitem getitem getitem getitem -> user=1 1332# \ / \ / 1333# cat (user=mul, dim=1) cat(user=mul, dim=1) 1334# | \ | \ 1335 1336# ################after transformation############# 1337 1338# split_node(dim=1) 1339# / ... \ 1340# getitem getitem 1341# | \ | \ 1342 1343 1344def has_same_parent_node(node: torch.fx.Node): 1345 # the input nodes of the node should come from the same parent 1346 prev_node = None 1347 for getitem in node.args[0]: # type: ignore[union-attr] 1348 if getitem.target != operator.getitem: # type: ignore[union-attr] 1349 return False 1350 if prev_node is None: 1351 prev_node = getitem.args[0] # type: ignore[union-attr] 1352 else: 1353 if getitem.args[0] != prev_node: 1354 return False 1355 return True 1356 1357 1358def remove_zeros(split_sections: List[int]): 1359 """ 1360 Remove zeros from the list and get the index mapping dict from getitem 1361 in split node to getitem in new split node 1362 """ 1363 new_split_sections, index_mapping = [], {} 1364 idx = 0 1365 for i in range(len(split_sections)): 1366 if split_sections[i] > 0: 1367 new_split_sections.append(split_sections[i]) 1368 index_mapping[i] = idx 1369 idx += 1 1370 1371 return new_split_sections, index_mapping 1372 1373 1374def is_sorted_and_consecutive(arr: List[int]) -> bool: 1375 # check if the array is sorted 1376 if arr == sorted(arr): 1377 # check if the differences between adjacent elements are all 1 1378 return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) 1379 else: 1380 return False 1381 1382 1383def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: List[int]) -> int: 1384 """ 1385 Calculate the fused tensor size in the indices 1386 """ 1387 fused_tensor_size = 0 1388 for i in range(len(split_node.args[1])): # type: ignore[arg-type] 1389 if i in indices: 1390 fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] 1391 return fused_tensor_size 1392 1393 1394@register_graph_pattern( 1395 CallFunction( 1396 torch.cat, 1397 getitem_split, 1398 dim=Ignored(), 1399 _users=MULTIPLE, 1400 ), 1401 pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"), 1402) 1403def merge_getitem_cat(match: Match, split_sections: List[int], dim: int): 1404 if not isinstance(split_sections, (list, tuple)): # Unnormalized split 1405 return 1406 graph = match.graph 1407 split_node = next(node for node in match.nodes if node.target == torch.split) 1408 split_input, split_size, split_dim = _get_split_args_default(split_node) 1409 # if the cat and split have different dims, return 1410 # Find the next users (i.e. users after the getitem) 1411 next_users = find_next_users(split_node) 1412 # 'immutable_list' object does not support mutation. Create a new copy of it 1413 split_sections = list(split_sections) 1414 for cat_user in next_users: 1415 if cat_user.target == torch.cat: 1416 cat_dim = get_arg_value(cat_user, 1, "dim") 1417 # check the all getitems in the cat_user from the same node 1418 # check the input of the cat has all getitem from the split 1419 # check all getitem only has one single user 1420 if ( 1421 split_dim != cat_dim 1422 or not has_same_parent_node(cat_user) 1423 or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] 1424 ): 1425 continue 1426 # find the index of getitems to be cated/stacked 1427 indices = [] 1428 for arg in cat_user.args[0]: # type: ignore[union-attr] 1429 indices.append(arg.args[1]) # type: ignore[union-attr] 1430 # the gettitems to be merged must be consecutive, otherwise 1431 # returned sliced tensor could be wrong 1432 if not is_sorted_and_consecutive(indices): 1433 continue 1434 # update the arg of cat user, only keep the first getitem 1435 cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] 1436 # calculate the fused tensor sizes in the indices 1437 fused_tensor_size = 0 1438 for i in range(len(split_node.args[1])): # type: ignore[arg-type] 1439 if i in indices: 1440 fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] 1441 # update the split sections 1442 split_sections[indices[0]] = calculate_fused_tensor_size( 1443 split_node, indices 1444 ) 1445 # padding others with zeros to keep the same dict size 1446 for i in indices[1:]: 1447 split_sections[i] = 0 1448 # remove all unused indexes in the split_node 1449 new_split_sections, index_mapping = remove_zeros(split_sections) 1450 with graph.inserting_after(split_node): 1451 new_split_node = graph.call_function( 1452 torch.split, 1453 args=(split_input, split_sections), 1454 kwargs={"dim": split_dim}, 1455 ) 1456 split_node.replace_all_uses_with(new_split_node) 1457 new_split_node.meta.update(split_node.meta) 1458 # remove all unused getitem nodes 1459 to_remove = [cat_user] 1460 # dictionary keys changed during iteration 1461 new_split_getitem_nodes = list(new_split_node.users.keys()) 1462 for getitem_node in new_split_getitem_nodes: 1463 if getitem_node.args[1] in indices[1:]: 1464 to_remove.append(getitem_node) 1465 # update meta data of getitem 1466 elif getitem_node.args[1] == indices[0]: 1467 cat_user.replace_all_uses_with(getitem_node) 1468 getitem_node.meta.update(cat_user.meta) 1469 else: 1470 # update getitem index for new split node 1471 getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) 1472 graph.erase_node(split_node) 1473 for getitem_node in to_remove: 1474 graph.erase_node(getitem_node) 1475 # update the split sections of new split node 1476 new_split_node.update_arg(1, new_split_sections) 1477 split_node = new_split_node 1478 split_sections = new_split_sections 1479 1480 counters["inductor"]["merge_getitem_cat_pass"] += 1 1481 1482 1483# ############pattern to be optimized is######### 1484 1485# split_node(dim=1) -> user=multiple 1486# / \ ... / \ 1487# getitem getitem getitem getitem -> user=multiple 1488# \ \ / \ 1489# other_op /cat(user=mul, dim=1) other_op 1490# | 1491 1492# ################after transformation############# 1493 1494# split_node(dim=1) -> -> user=multiple 1495# / \ ... / \ 1496# getitem getitem getitem getitem -> user=multiple 1497# \ \ / \ 1498# other_op 1499 1500 1501@register_graph_pattern( 1502 CallFunction( 1503 torch.cat, 1504 getitem_split, 1505 dim=Ignored(), 1506 _users=MULTIPLE, 1507 ), 1508 pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"), 1509) 1510def mutate_cat_node(match: Match, split_sections: List[int], dim: int): 1511 if not isinstance(split_sections, (list, tuple)): # Unnormalized split 1512 return 1513 graph = match.graph 1514 split_node = next(node for node in match.nodes if node.target == torch.split) 1515 split_input, split_size, split_dim = _get_split_args_default(split_node) 1516 # if the cat and split have different dims, return 1517 # Find the next users (i.e. users after the getitem) 1518 next_users = find_next_users(split_node) 1519 for cat_user in next_users: 1520 if cat_user.target == torch.cat: 1521 cat_dim = get_arg_value(cat_user, 1, "dim") or 0 1522 # check that all getitems in the cat_user from the same node 1523 # check the input of the cat has all getitem from the split 1524 if split_dim != cat_dim or not has_same_parent_node(cat_user): 1525 continue 1526 # find the index of getitems to be cat 1527 indices, idx_to_getitem = [], {} 1528 for getitem in cat_user.args[0]: # type: ignore[union-attr] 1529 indices.append(getitem.args[1]) # type: ignore[union-attr] 1530 idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] 1531 # the gettitems to be merged must be consecutive, otherwise 1532 # returned sliced tensor could be wrong 1533 if not is_sorted_and_consecutive(indices): 1534 continue 1535 # case 1: the cat uses all getitems from the split 1536 if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] 1537 # replace the users of the cat node to be the input of the split node 1538 cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] 1539 # remove the cat node 1540 graph.erase_node(cat_user) 1541 counters["inductor"]["mutate_cat_pass"] += 1 1542 # case 2: the cat uses some getitems from the split 1543 elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] 1544 # check the split dim, and construct the slice tuple 1545 start_fused_size = calculate_fused_tensor_size( 1546 split_node, list(range(indices[0])) 1547 ) 1548 end_fused_size = start_fused_size + calculate_fused_tensor_size( 1549 split_node, indices 1550 ) 1551 slice_list = [] 1552 for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] 1553 if i != split_dim: 1554 slice_list.append(slice(None, None, None)) 1555 else: 1556 slice_list.append(slice(start_fused_size, end_fused_size, None)) 1557 with graph.inserting_after(split_node): 1558 slice_node = graph.call_function( 1559 operator.getitem, 1560 args=(split_node.args[0], tuple(slice_list)), 1561 ) 1562 cat_user.replace_all_uses_with(slice_node) 1563 slice_node.meta.update(cat_user.meta) 1564 1565 # remove the cat node 1566 graph.erase_node(cat_user) 1567 counters["inductor"]["mutate_cat_pass"] += 1 1568 1569 1570@register_graph_pattern( 1571 CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), 1572 pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), 1573) 1574def normalize_cat_default_aten(match: Match, *args, **kwargs): 1575 cat_node = match.nodes[0] 1576 graph = match.graph 1577 tensors = get_arg_value(cat_node, 0, "tensors") 1578 cat_dim = get_arg_value(cat_node, 1, "dim") 1579 if cat_dim is None: 1580 cat_axis = cat_node.kwargs.get("axis") 1581 if cat_axis is not None: 1582 cat_dim = cat_axis 1583 else: 1584 cat_dim = 0 1585 if tensors is None or cat_dim is None: 1586 log.debug("couldn't find cat args") 1587 return 1588 assert isinstance(tensors, (list, tuple)) 1589 for tensor in itertools.chain([cat_node], tensors): 1590 if "val" not in tensor.meta: 1591 log.debug("val absent for node: %s", tensor) 1592 return 1593 1594 ndim = cat_node.meta["val"].dim() 1595 1596 def is_empty_tensor(x: torch.fx.Node) -> bool: 1597 # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor 1598 x_shape = x.meta["val"].shape 1599 return len(x_shape) == 1 and x_shape[0] == 0 1600 1601 assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) 1602 1603 if cat_dim < 0: # Normalize cat dim 1604 cat_dim += ndim 1605 1606 with graph.inserting_after(cat_node): 1607 new_cat_node = graph.call_function( 1608 torch.ops.aten.cat.default, 1609 args=(tensors,), 1610 kwargs={"dim": cat_dim}, 1611 ) 1612 cat_node.replace_all_uses_with(new_cat_node) 1613 new_cat_node.meta.update(cat_node.meta) 1614 graph.erase_node(cat_node) 1615 counters["inductor"]["normalization_aten_pass"] += 1 1616 1617 1618@register_graph_pattern( 1619 CallFunction( 1620 torch.ops.aten.cat, 1621 ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), 1622 _users=MULTIPLE, 1623 ), 1624 pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), 1625) 1626def merge_unbind_stack_aten(match: Match, *args, **kwargs): 1627 node = match.nodes[-1] 1628 graph = match.graph 1629 # pyre-fixme[6] 1630 unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] 1631 cat_dim = get_arg_value(node, 1, "dim") 1632 # check the unsqueeze nodes come from the select nodes 1633 if not all( 1634 get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select 1635 for unsqueeze_node in unsqueeze_nodes 1636 ): 1637 return 1638 select_nodes = [ 1639 get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes 1640 ] 1641 parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") 1642 # check the target of select_nodes are the same 1643 if not all( 1644 select_node.target == torch.ops.aten.select for select_node in select_nodes 1645 ): 1646 return 1647 # check the select nodes come from the same parent node 1648 if not all( 1649 get_arg_value(select_node, 0, "input") == parent_of_select_node 1650 for select_node in select_nodes 1651 ): 1652 return 1653 if len(unsqueeze_nodes) != len(select_nodes): 1654 return 1655 # check the select nodes have the same dim 1656 if not all( 1657 get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes 1658 ): 1659 return 1660 # check the select nodes have consecutive indices starting from 0 1661 if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( 1662 [get_arg_value(select_node, 2, "index") for select_node in select_nodes] 1663 ): 1664 return 1665 # check the users of parent of select node only from unsqueeze nodes that go to the cat node 1666 # we simply check the number of users of the parent of select node 1667 if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] 1668 return 1669 node.replace_all_uses_with(parent_of_select_node) 1670 graph.erase_node(node) 1671 for unsqueeze_node in unsqueeze_nodes: 1672 graph.erase_node(unsqueeze_node) 1673 for select_node in select_nodes: 1674 if len(select_node.users) == 0: 1675 graph.erase_node(select_node) 1676 counters["inductor"]["unbind_stack_aten_pass"] += 1 1677 1678 1679def divide_into_consecutive_sublists(indices: List[int]) -> List[List[int]]: 1680 n = len(indices) 1681 if n <= 1: 1682 return [indices] 1683 1684 # Initialize the list of sublists 1685 sublists = [] 1686 1687 # Iterate over the indices 1688 i = 0 1689 while i < n: 1690 # Initialize the current sublist 1691 sublist = [indices[i]] 1692 1693 # Iterate over the remaining indices 1694 j = i + 1 1695 while j < n and indices[j] == indices[j - 1] + 1: 1696 # Add the next index to the current sublist 1697 sublist.append(indices[j]) 1698 j += 1 1699 1700 # Add the current sublist to the list of sublists 1701 sublists.append(sublist) 1702 # Move to the next index 1703 i = j 1704 1705 return sublists 1706 1707 1708def update_args_from_split_getitem( 1709 graph: torch.fx.Graph, 1710 node: torch.fx.Node, 1711 getitem_indices: List[int], 1712 parents_seen: List[torch.fx.Node], 1713 new_cat_args: List[torch.fx.Node], 1714 new_cat_args_meta: List[torch.fx.Node], 1715 idx_to_getitems: Dict[int, torch.fx.Node], 1716 threshold_to_cat: int = 2, 1717): 1718 split_input, split_size, split_dim = _get_split_args_default(parents_seen[-1]) 1719 # case 1: the number of getitems is the same as the split size, elimiate the split 1720 if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( 1721 getitem_indices 1722 ): 1723 # we can merge the getitems from the previous parent 1724 new_cat_args.append(split_input) 1725 new_cat_args_meta.append(split_input.meta["example_value"]) 1726 else: 1727 if len(getitem_indices) > 0: 1728 # case 2: the number of getitems is smaller than the split size but larger than the threshold, and 1729 # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups 1730 geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices) 1731 for sublist in geitem_indices_sublist: 1732 if len(sublist) >= threshold_to_cat: 1733 # case 2: the number of getitems is smaller than the split size but larger than the threshold 1734 # we need to slice the input of parent 1735 start_fused_size = sum(split_size[: sublist[0]]) 1736 end_fused_size = sum(split_size[: sublist[-1] + 1]) 1737 slice_list = [] 1738 for i in range(len(split_input.meta["example_value"].shape)): # type: ignore[union-attr] 1739 if i != split_dim: 1740 slice_list.append(slice(None, None, None)) 1741 else: 1742 slice_list.append( 1743 slice(start_fused_size, end_fused_size, None) 1744 ) 1745 with graph.inserting_after(node): 1746 slice_node = graph.call_function( 1747 operator.getitem, 1748 args=(split_input, tuple(slice_list)), 1749 ) 1750 slice_node.meta["example_value"] = split_input.meta[ 1751 "example_value" 1752 ][tuple(slice_list)] 1753 new_cat_args.append(slice_node) 1754 new_cat_args_meta.append(slice_node.meta["example_value"]) 1755 else: 1756 # case 3: the number of getitems is smaller than the threshold, no merge is done 1757 # get the getitems based on the indexes 1758 for i in sublist: 1759 new_cat_args.append(idx_to_getitems[i]) 1760 new_cat_args_meta.append( 1761 idx_to_getitems[i].meta["example_value"] 1762 ) 1763 1764 1765def reshape_cat_node( 1766 graph: torch.fx.Graph, 1767 cat_node: torch.fx.Node, 1768 unbind_input: torch.fx.Node, 1769 cat_dim: int, 1770 unbind_dim: int, 1771 cat_shape: torch.Size, 1772) -> torch.fx.Node: 1773 if cat_dim != unbind_dim: 1774 # construct the permute node args, which has the same shape as the slice node 1775 # then it has the same dim as the unbind_input, i.e., shape of cat + 1 1776 with graph.inserting_after(cat_node): 1777 permute_list = list(range(len(cat_shape) + 1)) 1778 permute_list[unbind_dim], permute_list[cat_dim] = ( 1779 permute_list[cat_dim], 1780 permute_list[unbind_dim], 1781 ) 1782 permute_node = graph.call_function( 1783 torch.permute, 1784 args=(unbind_input, permute_list), 1785 ) 1786 permute_node.meta["example_value"] = torch.permute( 1787 unbind_input.meta["example_value"], permute_list 1788 ) # type: ignore[arg-type] 1789 else: 1790 permute_node = unbind_input 1791 with graph.inserting_after(permute_node): 1792 reshape_node = graph.call_function( 1793 torch.reshape, args=(permute_node, tuple(cat_shape)) 1794 ) 1795 reshape_node.meta["example_value"] = torch.reshape( 1796 permute_node.meta["example_value"], tuple(cat_shape) 1797 ) # type: ignore[arg-type] 1798 return reshape_node 1799 1800 1801def update_args_from_unbind_getitem( 1802 graph: torch.fx.Graph, 1803 node: torch.fx.Node, # cat or stack node 1804 getitem_indices: List[int], 1805 parents_seen: List[torch.fx.Node], 1806 new_cat_args: List[torch.fx.Node], 1807 new_cat_args_meta: List[torch.fx.Node], 1808 idx_to_getitems: Dict[int, torch.fx.Node], 1809 threshold_to_cat: int = 2, 1810): 1811 unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input 1812 unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim 1813 cat_dim = get_arg_value(node, 1, "dim") # cat or stack dim 1814 # case 1: the number of getitems is the same as the split size, elimiate the split 1815 size = list(unbind_input.meta["example_value"].shape)[unbind_dim] 1816 if size == len(getitem_indices): 1817 cat_shape = torch.cat( 1818 [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], 1819 dim=cat_dim, 1820 ).shape 1821 # we can merge the getitems from the previous parent 1822 reshape_node = reshape_cat_node( 1823 graph, node, unbind_input, cat_dim, unbind_dim, cat_shape 1824 ) 1825 new_cat_args.append(reshape_node) 1826 new_cat_args_meta.append(reshape_node.meta["example_value"]) 1827 elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive( 1828 getitem_indices 1829 ): 1830 # case 2: the number of getitems is smaller than the split size but larger than the threshold 1831 # we need to slice the input of parent 1832 cat_shape = torch.cat( 1833 [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], 1834 dim=cat_dim, 1835 ).shape 1836 slice_list = [] 1837 for i in range(len(cat_shape) + 1): 1838 if i != unbind_dim: 1839 slice_list.append(slice(None, None, None)) # start, end, step 1840 else: 1841 slice_list.append( 1842 slice(getitem_indices[0], getitem_indices[-1] + 1, None) 1843 ) 1844 with graph.inserting_after(node): 1845 slice_node = graph.call_function( 1846 operator.getitem, 1847 args=(unbind_input, tuple(slice_list)), 1848 ) 1849 slice_node.meta["example_value"] = torch.narrow( 1850 unbind_input.meta["example_value"], 1851 unbind_dim, 1852 getitem_indices[0], 1853 getitem_indices[-1] - getitem_indices[0] + 1, 1854 ) 1855 reshape_node = reshape_cat_node( 1856 graph, node, slice_node, cat_dim, unbind_dim, cat_shape 1857 ) 1858 new_cat_args.append(reshape_node) 1859 new_cat_args_meta.append(reshape_node.meta["example_value"]) 1860 else: 1861 # case 3: the number of getitems is smaller than the threshold, no merge is done 1862 # get the getitems based on the indexes 1863 for i in getitem_indices: 1864 new_cat_args.append(idx_to_getitems[i]) 1865 new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"]) 1866 1867 1868def construct_cat_args( 1869 graph: torch.fx.Graph, 1870 cat_or_stack_node: torch.fx.Node, 1871 inputs: List[torch.fx.Node], 1872 split_or_unbind_node: torch.fx.Node, 1873 threshold_to_cat: int = 2, 1874 run_update_func: Callable = update_args_from_split_getitem, # type: ignore[type-arg] 1875) -> Tuple[List[torch.fx.Node], List[torch.Tensor]]: 1876 new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {} # type: ignore[var-annotated] 1877 new_cat_args_meta = [] # type: ignore[var-annotated] 1878 for input in inputs: 1879 if input.target != operator.getitem: 1880 # update the last arg based on getitem_indices and parents_seens 1881 if len(parents_seen) > 0: 1882 run_update_func( # type: ignore[arg-type, union-attr] 1883 graph, 1884 cat_or_stack_node, 1885 getitem_indices, 1886 parents_seen, 1887 new_cat_args, 1888 new_cat_args_meta, 1889 idx_to_getitems, # type: ignore[arg-type, union-attr] 1890 threshold_to_cat, 1891 ) 1892 new_cat_args.append(input) 1893 new_cat_args_meta.append(input.meta["example_value"]) 1894 # reset the indices array 1895 getitem_indices, idx_to_getitems = [], {} 1896 else: 1897 # get the parent node of the getitem input 1898 parent, idx = input.args[0], input.args[1] # type: ignore[union-attr] 1899 if parent.target != split_or_unbind_node.target: # type: ignore[union-attr] 1900 new_cat_args.append(input) 1901 new_cat_args_meta.append(input.meta["example_value"]) 1902 continue 1903 # cannot use parents_seen to check since the first item could be non getitem node 1904 if len(parents_seen) == 0: 1905 parents_seen.append(parent) 1906 idx_to_getitems[idx] = input 1907 getitem_indices.append(idx) 1908 # case: we only have one getitem input, and it is in the last position 1909 if input == inputs[-1]: 1910 new_cat_args.append(input) 1911 new_cat_args_meta.append(input.meta["example_value"]) 1912 continue 1913 # if it is the last input in the tensors, we also check if it can be optimized 1914 if parent != parents_seen[-1] or input == inputs[-1]: 1915 if input == inputs[-1]: 1916 getitem_indices.append(idx) 1917 idx_to_getitems[idx] = input 1918 run_update_func( # type: ignore[arg-type, union-attr] 1919 graph, 1920 cat_or_stack_node, 1921 getitem_indices, 1922 parents_seen, 1923 new_cat_args, 1924 new_cat_args_meta, 1925 idx_to_getitems, # type: ignore[arg-type, union-attr] 1926 threshold_to_cat, 1927 ) 1928 # reset the indices array for the next parent 1929 # remember to add the last element since it is the first 1930 # item in this round of parent 1931 # add the parent to the list of seen parents 1932 parents_seen.append(parent) 1933 getitem_indices, idx_to_getitems = [idx], {idx: input} 1934 else: 1935 getitem_indices.append(idx) 1936 idx_to_getitems[idx] = input 1937 return new_cat_args, new_cat_args_meta 1938 1939 1940def remove_split_unbind_children(graph: torch.fx.Graph, inputs: List[torch.fx.Node]): 1941 nodes = set() 1942 for input in inputs: 1943 if input.target == operator.getitem: 1944 nodes.add(input.args[0]) # type: ignore[union-attr] 1945 if len(input.users.keys()) == 0: 1946 graph.erase_node(input) 1947 # check the split node to remove if it has no users 1948 for node in nodes: 1949 if len(node.users.keys()) == 0: # type: ignore[union-attr] 1950 graph.erase_node(node) # type: ignore[arg-type] 1951 1952 1953# ############pattern to be optimized is######### 1954 1955# split_node(dim=1) -> user=multiple 1956# / \ ... / \ 1957# other inputs getitem getitem getitem -> user=multiple 1958# \ / \ 1959# cat(user=mul, dim=1) other_op 1960# | 1961 1962# ################after transformation############# 1963 1964# split_node(dim=1) other inputs -> -> user=multiple 1965# / \ 1966# cat (user=mul, dim=1, split_node) 1967 1968 1969@register_graph_pattern( 1970 CallFunctionVarArgs(torch.cat, users=MULTIPLE), 1971 pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), 1972) 1973@register_graph_pattern( 1974 CallFunction( 1975 torch.cat, 1976 getitem_split, 1977 dim=Ignored(), 1978 _users=MULTIPLE, 1979 ), 1980 pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), 1981) 1982def split_cat_to_slices(match: Match, split_sections: List[int], dim: int): 1983 if not isinstance(split_sections, (list, tuple)): # Unnormalized split 1984 return 1985 split_nodes = [node for node in match.nodes if node.target == torch.split] 1986 if split_nodes: 1987 split_node = next(node for node in split_nodes) 1988 else: 1989 # Handle the case where there are no nodes with a target of torch.split 1990 return 1991 split_dim = get_arg_value(split_node, 2, "dim") or 0 1992 graph = match.graph 1993 threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ 1994 "split_cat_to_slices_pass" 1995 ].get("threshold_to_cat", 10) 1996 # get the cat_node and check its inputs and meta data 1997 next_users = find_next_users(split_node) 1998 for cat_node in next_users: 1999 if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): 2000 continue 2001 cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] 2002 new_cat_args, _ = construct_cat_args( 2003 graph, 2004 cat_node, 2005 cat_inputs, 2006 split_node, 2007 threshold_to_cat, 2008 update_args_from_split_getitem, 2009 ) 2010 # At least one node would be in the returned new_cat_args 2011 # case 1: if new cat args has length 1, we can remove the cat node 2012 if len(new_cat_args) == 1: 2013 cat_node.replace_all_uses_with(new_cat_args[0]) 2014 # remove inputs of cat_node if they have no users 2015 cat_inputs = cat_node.args[0] # type: ignore[union-attr] 2016 graph.erase_node(cat_node) 2017 remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] 2018 counters["inductor"]["split_cat_to_slices_pass"] += 1 2019 continue 2020 if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): 2021 new_args = (new_cat_args,) 2022 with graph.inserting_after(cat_node): 2023 new_cat_node = graph.call_function( 2024 torch.cat, 2025 args=new_args, 2026 # split and cat have the same dim 2027 kwargs={"dim": split_dim}, 2028 ) 2029 cat_node.replace_all_uses_with(new_cat_node) 2030 new_cat_node.meta.update(cat_node.meta) 2031 # remove the cat node 2032 graph.erase_node(cat_node) 2033 remove_split_unbind_children(graph, cat_inputs) 2034 counters["inductor"]["split_cat_to_slices_pass"] += 1 2035 2036 2037# ############pattern to be optimized is######### 2038 2039# unbind(dim=0) -> user=multiple 2040# / \ ... / \ 2041# getitem getitem getitem getitem -> user=multiple 2042# \ / \ 2043# cat(user=mul, dim=1) other_op 2044# | 2045 2046# ################after transformation############# 2047 2048# input_of_unbind 2049# | \ 2050# slice 2051# | 2052# view 2053# | 2054 2055 2056@register_graph_pattern( 2057 CallFunction( 2058 torch.cat, 2059 getitem_unbind, 2060 dim=Ignored(), 2061 _users=MULTIPLE, 2062 ), 2063 pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), 2064) 2065def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): 2066 unbind_node = next(node for node in match.nodes if node.target == torch.unbind) 2067 graph = match.graph 2068 # get the cat_node and check its inputs and meta data 2069 next_users = find_next_users(unbind_node) 2070 threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ 2071 "unbind_cat_to_view_pass" 2072 ].get("threshold_to_cat", 10) 2073 # get the cat_node and check its inputs and meta data 2074 for cat_node in next_users: 2075 if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): 2076 continue 2077 inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] 2078 new_cat_args, new_cat_args_meta = construct_cat_args( 2079 graph, 2080 cat_node, 2081 inputs, 2082 unbind_node, 2083 threshold_to_cat, 2084 update_args_from_unbind_getitem, 2085 ) 2086 # get the view shape 2087 # At least one node would be in the returned new_cat_args 2088 # case 1: only one node in the new cat args, don't need to cat 2089 if len(new_cat_args) == 1: 2090 cat_node.replace_all_uses_with(new_cat_args[0]) 2091 # remove inputs of cat_node if they have no users 2092 cat_inputs = cat_node.args[0] # type: ignore[union-attr] 2093 graph.erase_node(cat_node) 2094 remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] 2095 counters["inductor"]["unbind_cat_to_view_pass"] += 1 2096 continue 2097 if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): 2098 # get the view shape 2099 cat_dim = get_arg_value(cat_node, 1, "dim") 2100 with graph.inserting_after(cat_node): 2101 new_cat_node = graph.call_function( 2102 torch.cat, 2103 args=(new_cat_args,), 2104 kwargs={"dim": cat_dim}, 2105 ) 2106 new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type] 2107 cat_node.replace_all_uses_with(new_cat_node) 2108 new_cat_node.meta.update(cat_node.meta) 2109 # remove inputs of cat_node if they have no users 2110 cat_inputs = cat_node.args[0] # type: ignore[union-attr] 2111 graph.erase_node(cat_node) 2112 remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] 2113 counters["inductor"]["unbind_cat_to_view_pass"] += 1 2114 2115 2116def reshape_cat_node_to_stack( 2117 graph: torch.fx.Graph, 2118 cat_node: torch.fx.Node, 2119 stack_node: torch.fx.Node, 2120 split_or_unbind_dim: int, 2121) -> None: 2122 # reshape the cat node to the stack node shape 2123 stack_shape = stack_node.meta["example_value"].shape 2124 stack_dim = _get_dim(stack_node) 2125 if stack_dim != split_or_unbind_dim: 2126 # case 1: the stack dim is not the same as the split dim 2127 # we need to reshape the split input before we do the reshape 2128 reshape_list = list(stack_shape) 2129 reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = ( 2130 reshape_list[split_or_unbind_dim], 2131 reshape_list[stack_dim], 2132 ) 2133 reshape_node = graph.call_function( 2134 torch.reshape, 2135 args=(cat_node, tuple(reshape_list)), 2136 ) 2137 reshape_node.meta["example_value"] = torch.reshape( 2138 cat_node.meta["example_value"], tuple(reshape_list) 2139 ) 2140 permute_list = list(range(len(stack_shape))) 2141 permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( 2142 permute_list[split_or_unbind_dim], 2143 permute_list[stack_dim], 2144 ) 2145 permute_node = graph.call_function( 2146 torch.permute, 2147 args=(reshape_node, permute_list), 2148 ) 2149 permute_node.meta["example_value"] = torch.permute( 2150 reshape_node.meta["example_value"], permute_list 2151 ) 2152 else: 2153 # case 2: the stack dim is the same as the split dim 2154 # we can directly reshape the split input 2155 permute_node = cat_node 2156 reshape_node = graph.call_function( 2157 torch.Tensor.view, 2158 args=(permute_node, *stack_shape), # type: ignore[arg-type] 2159 ) 2160 stack_node.replace_all_uses_with(reshape_node) 2161 reshape_node.meta.update(stack_node.meta) 2162 stack_inputs = stack_node.args[0] # type: ignore[union-attr] 2163 # remove stack node 2164 graph.erase_node(stack_node) 2165 # check the input of stack node, and remove nodes that have no users 2166 remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] 2167 2168 2169def convert_reshape_cat_arg_to_stack( 2170 graph: torch.fx.Graph, 2171 cat_node: torch.fx.Node, 2172 stack_node: torch.fx.Node, 2173 stack_node_shape: torch.Size, 2174 stack_dim: int, 2175 split_dim: int, 2176) -> torch.fx.Node: 2177 # reshape the cat node to the stack node shape 2178 cat_shape = cat_node.meta["example_value"].shape 2179 if stack_dim != split_dim: 2180 permute_list = list(range(len(cat_shape))) 2181 permute_list[stack_dim], permute_list[split_dim] = ( 2182 permute_list[split_dim], 2183 permute_list[stack_dim], 2184 ) 2185 permute_node = graph.call_function( 2186 torch.permute, 2187 args=(cat_node, permute_list), 2188 ) 2189 permute_node.meta["example_value"] = torch.permute( 2190 cat_node.meta["example_value"], permute_list 2191 ) 2192 else: 2193 permute_node = cat_node 2194 reshape_node = graph.call_function( 2195 torch.Tensor.view, 2196 args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] 2197 ) 2198 reshape_node.meta["example_value"] = torch.Tensor.view( 2199 permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type] 2200 ) 2201 return reshape_node 2202 2203 2204# ############pattern to be optimized is######### 2205# | | 2206# split split (dim=1) 2207# / \ / \ 2208# getitem ... getitem other ops 2209# \ | / / 2210# stack(user=mul, dim=1 or 2) -> can be different dim 2211# | 2212 2213# ################after transformation############# 2214 2215# / \ ... / \ 2216# getitem getitem getitem getitem -> user=multiple 2217# \ / 2218# cat(user=mul, dim=1) cat_other_opts 2219# \ / 2220# cat 2221# | 2222# view 2223# | 2224 2225 2226@register_graph_pattern( 2227 CallFunction( 2228 torch.stack, 2229 getitem_split, 2230 dim=Ignored(), 2231 _users=MULTIPLE, 2232 ), 2233 pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"), 2234) 2235def split_stack_to_cats(match: Match, split_sections: List[int], dim: int): 2236 if not isinstance(split_sections, (list, tuple)): # Unnormalized split 2237 return 2238 split_node = next(node for node in match.nodes if node.target == torch.split) 2239 split_dim = get_arg_value(split_node, 2, "dim") or 0 2240 graph = match.graph 2241 threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ 2242 "split_stack_to_cats_pass" 2243 ].get("threshold_to_cat", 10) 2244 # get the stack_node and check its inputs and meta data 2245 next_users = find_next_users(split_node) 2246 for stack_node in next_users: 2247 if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): 2248 continue 2249 inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] 2250 new_cat_args, new_cat_args_meta = construct_cat_args( 2251 graph, 2252 stack_node, 2253 inputs, 2254 split_node, 2255 threshold_to_cat, 2256 update_args_from_split_getitem, 2257 ) 2258 # At least one node would be in the returned new_cat_args 2259 # case 1: only one node in the new cat args, don't need to cat 2260 if len(new_cat_args) == 1: 2261 reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) 2262 counters["inductor"]["split_stack_to_cats_pass"] += 1 2263 continue 2264 if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): 2265 with graph.inserting_after(stack_node): 2266 cat_node = graph.call_function( 2267 torch.cat, 2268 args=(new_cat_args,), 2269 kwargs={"dim": split_dim}, 2270 ) 2271 cat_node.meta["example_value"] = torch.cat( # type: ignore[arg-type] 2272 new_cat_args_meta, dim=split_dim 2273 ) 2274 reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) 2275 counters["inductor"]["split_stack_to_cats_pass"] += 1 2276 2277 2278# ############pattern to be optimized is######### 2279 2280# unbind(dim=1) -> user=multiple 2281# \ ... / \ 2282# others getitem getitem getitem -> user=multiple 2283# \ \ / \ 2284# stack(user=mul, dim=1) other_op 2285# | 2286 2287# ################after transformation############# 2288 2289# input_of_unbind 2290# | \ 2291# slice 2292# | 2293# view others 2294# | / 2295# stack 2296# | 2297 2298 2299@register_graph_pattern( 2300 CallFunction( 2301 torch.stack, 2302 getitem_unbind, 2303 dim=Ignored(), 2304 _users=MULTIPLE, 2305 ), 2306 pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), 2307) 2308def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): 2309 unbind_node = next(node for node in match.nodes if node.target == torch.unbind) 2310 graph = match.graph 2311 # get the cat_node and check its inputs and meta data 2312 next_users = find_next_users(unbind_node) 2313 threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ 2314 "unbind_stack_to_slices_pass" 2315 ].get("threshold_to_cat", 10) 2316 # get the cat_node and check its inputs and meta data 2317 for stack_node in next_users: 2318 if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): 2319 continue 2320 inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] 2321 new_cat_args, new_cat_args_meta = construct_cat_args( 2322 graph, 2323 stack_node, 2324 inputs, 2325 unbind_node, 2326 threshold_to_cat, 2327 update_args_from_unbind_getitem, 2328 ) 2329 unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0 2330 # At least one node would be in the returned new_cat_args 2331 # case 1: only one node in the new cat args, don't need to cat 2332 if len(new_cat_args) == 1: 2333 reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) 2334 counters["inductor"]["unbind_stack_to_slices_pass"] += 1 2335 continue 2336 if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): 2337 # get the view shape 2338 cat_dim = get_arg_value(stack_node, 1, "dim") 2339 with graph.inserting_after(stack_node): 2340 new_cat_node = graph.call_function( 2341 torch.cat, 2342 args=(new_cat_args,), 2343 kwargs={"dim": cat_dim}, 2344 ) 2345 new_cat_node.meta["example_value"] = torch.cat( 2346 new_cat_args_meta, dim=cat_dim 2347 ) 2348 reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) 2349 counters["inductor"]["unbind_stack_to_slices_pass"] += 1 2350 2351 2352# ############pattern to be optimized is######### 2353# input 2354# | 2355# split(dim=1) -> user=multiple 2356# \ \ 2357# others getitem getitem 2358# \ \ / 2359# reshape reshape reshape other_op 2360# \ \ / / 2361# stack(user=mul, dim=0) 2362# | 2363 2364# ################after transformation############# 2365# input 2366# | 2367# permute 2368# | 2369# reshape others 2370# | / 2371# cat (dim=0) 2372# | 2373 2374 2375def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> List[int]: 2376 # cat_arg must be the split input 2377 view_shape_list = [] 2378 for user in cat_arg.users.keys(): 2379 if user.target == torch.split: 2380 for getitem in user.users.keys(): 2381 if getitem.target == operator.getitem: 2382 reshape_user = [ 2383 user 2384 for user in getitem.users.keys() 2385 if user.target == torch.reshape 2386 ] 2387 if len(reshape_user) > 0: 2388 view_shape_list = list( 2389 reshape_user[0] 2390 .meta["example_value"] 2391 .unsqueeze(stack_dim) 2392 .shape 2393 ) 2394 view_shape_list[stack_dim] = -1 2395 return view_shape_list 2396 return view_shape_list 2397 2398 2399@register_graph_pattern( 2400 CallFunction( 2401 torch.stack, 2402 reshape_getitem_split, 2403 dim=Ignored(), 2404 _users=MULTIPLE, 2405 ), 2406 pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), 2407) 2408def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): 2409 split_node = next(node for node in match.nodes if node.target == torch.split) 2410 split_dim = _get_dim(split_node) 2411 split_users = list(split_node.users.keys()) 2412 stack_nodes = [node for node in match.nodes if node.target == torch.stack] 2413 graph = match.graph 2414 threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ 2415 "move_reshape_out_of_split_stack_pass" 2416 ].get("threshold_to_cat", 10) 2417 for stack_node in stack_nodes: 2418 if not is_node_meta_valid(stack_node): 2419 log.debug("example value absent for node: %s", stack_node) 2420 continue 2421 stack_dim = _get_dim(stack_node) 2422 stack_inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] 2423 inputs = [] 2424 for stack_input in stack_inputs: 2425 if stack_input.target != torch.reshape: 2426 inputs.append(stack_input) 2427 else: 2428 inputs.append(stack_input.args[0]) # type: ignore[union-attr] 2429 new_cat_args, new_cat_args_meta = construct_cat_args( 2430 graph, 2431 stack_node, 2432 inputs, 2433 split_node, 2434 threshold_to_cat, 2435 update_args_from_split_getitem, 2436 ) 2437 # At least one node would be in the returned new_cat_args 2438 # case 1: only one node in the new cat args, don't need to cat 2439 if len(new_cat_args) == 1: 2440 reshape_node = convert_reshape_cat_arg_to_stack( 2441 graph, 2442 new_cat_args[0], 2443 stack_node, 2444 stack_node.meta["example_value"].shape, 2445 stack_dim, 2446 split_dim, 2447 ) 2448 stack_node.replace_all_uses_with(reshape_node) 2449 # remove stack node 2450 graph.erase_node(stack_node) 2451 # check the input of stack node, and remove nodes that have no users 2452 remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] 2453 remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] 2454 counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 2455 continue 2456 if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): 2457 # decompose the cat args into multiple stack nodes, i.e., we stack 2458 # all the nodes exist in the stack inputs and reshape the rest followed by a cat 2459 stack_node_input, stack_node_input_meta, cat_inputs = [], [], [] # type: ignore[var-annotated] 2460 for cat_arg in new_cat_args: 2461 if cat_arg not in stack_inputs: 2462 if len(stack_node_input) > 0: 2463 with graph.inserting_after(stack_node): 2464 decomposed_stack_node = graph.call_function( 2465 torch.stack, 2466 args=(stack_node_input,), 2467 kwargs={"dim": stack_dim}, 2468 ) 2469 decomposed_stack_node.meta["example_value"] = torch.stack( 2470 stack_node_input_meta, dim=stack_dim 2471 ) 2472 cat_inputs.append(decomposed_stack_node) 2473 # cat_arg must be the split input 2474 view_shape_list = get_view_shape_list(cat_arg, stack_dim) 2475 stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr] 2476 cat_inputs.append( 2477 convert_reshape_cat_arg_to_stack( 2478 graph, 2479 cat_arg, 2480 stack_node, 2481 stack_node_shape, 2482 stack_dim, 2483 split_dim, 2484 ) 2485 ) 2486 stack_node_input, stack_node_input_meta = [], [] 2487 else: 2488 stack_node_input.append(cat_arg) 2489 stack_node_input_meta.append(cat_arg.meta["example_value"]) 2490 2491 if len(stack_node_input) > 0: 2492 with graph.inserting_after(stack_node): 2493 decomposed_stack_node = graph.call_function( 2494 torch.stack, 2495 args=(stack_node_input,), 2496 kwargs={"dim": stack_dim}, 2497 ) 2498 decomposed_stack_node.meta["example_value"] = torch.stack( 2499 stack_node_input_meta, dim=stack_dim 2500 ) 2501 cat_inputs.append(decomposed_stack_node) 2502 2503 with graph.inserting_after(stack_node): 2504 cat_node = graph.call_function( 2505 torch.cat, 2506 args=(cat_inputs,), 2507 kwargs={"dim": stack_dim}, 2508 ) 2509 stack_node.replace_all_uses_with(cat_node) 2510 cat_node.meta.update(stack_node.meta) 2511 graph.erase_node(stack_node) 2512 remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] 2513 remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] 2514 counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 2515