# mypy: allow-untyped-defs
import copy
import itertools
import logging
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
from torch._utils_internal import upload_graph
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights

from .. import config
from ..fx_utils import matches_module_function_pattern
from ..pattern_matcher import (
    init_once_fakemode,
    PatternMatcherPass,
    stable_topological_sort,
)
from ..utils import is_cpu_device, pass_execution_and_save
from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
from .misc_patterns import numpy_compat_normalization
from .split_cat import PRE_GRAD_PATTERNS


log = logging.getLogger(__name__)

efficient_conv_bn_eval_pass = PatternMatcherPass(
    pass_name="efficient_conv_bn_eval_pass"
)

fuse_split_linear_add_pass = PatternMatcherPass(
    pass_name="fuse_split_linear_add_pass",
)
fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
    pass_name="fuse_chunk_squeeze_cat_pass",
)
remove_reshape_pass = PatternMatcherPass(
    pass_name="remove_reshape_pass",
)

# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass()
merge_splits_pass_aten = PatternMatcherPass()
split_cat_pass_aten = PatternMatcherPass()
unbind_stack_pass_aten = PatternMatcherPass()
merge_getitem_cat_pass_aten = PatternMatcherPass()
merge_stack_tahn_unbind_pass_aten = PatternMatcherPass()
mutate_cat_pass_aten = PatternMatcherPass()
remove_split_with_size_one_pass_aten = PatternMatcherPass()


def save_inductor_dict(pass_to_compare=None):
    if not pass_to_compare:
        pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list(
            config.post_grad_fusion_options.keys()
        )
    return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare}


def is_same_dict(inductor_dict, optimus_dict):
    for pass_name, count in optimus_dict.items():
        if count != dict(inductor_dict).get(pass_name, 0):
            return False
    return True
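

# Typical usage of the two helpers above (this mirrors the pattern-matcher
# loop in pre_grad_passes below): snapshot the counters before a pass, run
# the pass, then compare snapshots to see whether it actually fired, e.g.
#
#   before = save_inductor_dict([some_pass.pass_name])
#   some_pass.apply(gm.graph)
#   if not is_same_dict(counters["inductor"], before):
#       ...  # the pass changed the graph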


# The passes below default to no-ops here; real implementations may be
# patched in elsewhere (e.g. by internal builds; see lazy_init).
def normalize_node_kwargs_pass(graph):
    return None


def fuse_parallel_linear_pass(graph):
    return None


def remove_split_ops(graph, shape_prop):
    return None


def fuse_chunk_reshape_unsqueeze_concat_pass(graph):
    return None


def fuse_chunk_reshape_concat_pass(graph):
    return None


def remove_noop_pass(graph):
    return None


def stack_to_unsqueeze_pass(graph):
    return None


@init_once_fakemode
def lazy_init():
    from . import efficient_conv_bn_eval, split_cat  # noqa: F401

    if config.is_fbcode():
        from . import fb  # type: ignore[attr-defined]  # noqa: F401


def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
    """
    Apply passes on the input FX graph using Torch IR.

    WARNING:
    The IR before grad is neither functional nor normalized, so it is harder
    to write passes on this IR. Passes must be safe with respect to
    aliasing and mutation and need to handle all possible arg schemas.

    Consider adding a new pass to post_grad.py or joint_graph.py instead;
    those run after functionalization and normalization.
    """
    if config.pattern_matcher:
        lazy_init()
    if hasattr(
        config, "fx_passes_numeric_check"
    ) and config.fx_passes_numeric_check.get("pre_grad", False):
        gm_before_fx_passes = gm.__copy__()
    # explicitly run with predispatch aten IR based passes
    if config.is_predispatch:

        def shape_prop(mod) -> None:
            ShapeProp(
                gm=mod,
                # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
                fake_mode=detect_fake_mode(example_inputs),
            ).propagate(*example_inputs)
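
        # ShapeProp executes the module with fake tensors and records each
        # node's output shape/dtype in node.meta["tensor_meta"]. The closure
        # is handed to remove_split_ops below so metadata can be refreshed
        # after the graph is mutated; it is also re-run once at the end of
        # this branch.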

        # normalization pass
        pass_execution_and_save(
            normalization_pass_aten.apply,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply normalization pass",
        )
        # normalize kwargs; must run before the remaining passes
        pass_execution_and_save(
            normalize_node_kwargs_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply normalize_node_kwargs_pass",
        )
        pass_execution_and_save(
            remove_noop_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply remove_noop pass",
        )
        pass_execution_and_save(
            fuse_chunk_reshape_concat_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_concat_pass",
        )
        pass_execution_and_save(
            group_batch_fusion_passes,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply group_batch_fusion",
        )
        pass_execution_and_save(
            normalize_node_kwargs_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply normalize_node_kwargs_pass",
        )
        pass_execution_and_save(
            fuse_chunk_squeeze_cat_pass.apply,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
        )
        pass_execution_and_save(
            fuse_split_linear_add_pass.apply,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
        )
        pass_execution_and_save(
            remove_reshape_pass.apply,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
        )
        pass_execution_and_save(
            fuse_parallel_linear_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
        )
        pass_execution_and_save(
            lambda graph: remove_split_ops(graph.owning_module, shape_prop),
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply remove_split_ops",
        )
        # run before fuse_chunk_reshape_unsqueeze_concat_pass
        pass_execution_and_save(
            stack_to_unsqueeze_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply stack_to_unsqueeze_pass",
        )
        pass_execution_and_save(
            fuse_chunk_reshape_unsqueeze_concat_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_unsqueeze_concat_pass",
        )
        # Remove noops at the end; they may have been generated by other passes.
        pass_execution_and_save(
            remove_noop_pass,
            gm,
            example_inputs,
            "[Pre grad(predispatch IR)] Apply remove_noop pass",
        )
        shape_prop(gm)

    else:
        # Only log graphs that changed, to avoid excessive compilation time
        # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
        if example_inputs is not None:
            gm = fuse_fx(gm, example_inputs)
        numpy_compat_normalization(gm.graph)
        optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
        group_batch_fusion_passes(gm.graph, pre_grad=True)
        for pass_name in config.pre_grad_fusion_options:
            # skip all patterns for group batch fusions
            if pass_name in PRE_GRAD_FUSIONS:
                continue
            pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
            inductor_before_change = save_inductor_dict(
                [pattern_matcher_pass.pass_name]
            )
            # the same pattern can be run multiple times; the default is to run it once
            counter = config.pre_grad_fusion_options[pass_name].get("counter", 1)
            for _ in range(counter):
                pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
            if not is_same_dict(counters["inductor"], inductor_before_change):
                optimus_scuba_log[
                    f"{pattern_matcher_pass.pass_name}_pre_grad"
                ] = upload_graph(gm.graph)
        # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
        efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]

    if config.pre_grad_custom_pass is not None:
        with GraphTransformObserver(
            gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform
        ):
            config.pre_grad_custom_pass(gm.graph)
    stable_topological_sort(gm.graph)

    from .quantization import quant_lift_up

    quant_lift_up(gm)

    gm.graph.lint()
    gm.recompile()
    optimus_scuba_log["after_recompile_pre_grad"] = upload_graph(gm.graph)

    if (
        config.pattern_matcher
        and hasattr(config, "fx_passes_numeric_check")
        and config.fx_passes_numeric_check.get("pre_grad", False)
        and example_inputs is not None
    ):
        from .numeric_utils import numeric_check_if_enabled

        gm_after_fx_passes = gm.__copy__()
        numeric_check_if_enabled(
            gm_before_fx_passes,  # type: ignore[possibly-undefined]
            gm_after_fx_passes,
            example_inputs,
            config.fx_passes_numeric_check.get("num_iterations", 1),
            config.fx_passes_numeric_check.get("precision", 1e-4),
        )

    return gm


def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    is_cpu = is_cpu_device(example_inputs)
    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
    fake_mode = detect_fake_mode(example_inputs)

    gm = sink_cat_after_pointwise(gm)
    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        with GraphTransformObserver(
            gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = linear_permute_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_linear_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_matmul_fusion(gm)

    # the remaining fusions require autograd to be disabled and a CPU device
    if torch.is_grad_enabled() or not is_cpu:
        return gm
    if config.freezing:
        with GraphTransformObserver(
            gm, "remove_identity", config.trace.log_url_for_graph_xform
        ):
            gm = remove_identity(gm)
        with GraphTransformObserver(
            gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform
        ):
            gm = fuse_conv_bn(gm)
    return gm


def fetch_attr(target: str, mod):
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Removes all identity layers from the module.
    """

    class IdentityRemover(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            if isinstance(self.submodules[target], nn.Identity):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)

    return IdentityRemover(gm).transform()
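

# fuse_conv_bn handles two graph shapes: (1) an nn.ConvNd module followed by
# an nn.BatchNormNd module, and (2) an nn.ConvNd module followed by a call to
# F.batch_norm whose running stats and parameters are constant get_attr nodes.
# In both cases the batch norm is folded into a copy of the conv and the BN
# node is erased from the graph.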
def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
    """
    Fuses Convolution/BN layers for inference purposes.
    """
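    # Folding math (the standard conv-BN fold, as implemented by
    # fuse_conv_bn_eval / fuse_conv_bn_weights):
    #   scale   = bn.weight / sqrt(bn.running_var + bn.eps)
    #   w_fused = conv.weight * scale   (scale broadcast over output channels)
    #   b_fused = (conv.bias - bn.running_mean) * scale + bn.bias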
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())

    class ConvBNFusion:
        def __init__(
            self,
            bn_node,
            conv_module,
            bn_module=None,  # For BN Module
            bn_running_mean=None,  # For Functional BN
            bn_running_var=None,
            bn_eps=None,
            bn_weight=None,
            bn_bias=None,
        ) -> None:
            self.bn_nodes = [
                bn_node,
            ]
            self.conv_module = conv_module
            self.bn_module = bn_module
            self.bn_running_mean = bn_running_mean
            self.bn_running_var = bn_running_var
            self.bn_eps = bn_eps
            self.bn_weight = bn_weight
            self.bn_bias = bn_bias
            self.fusion_enabled = True

        def add_bn_node(self, bn_node):
            self.bn_nodes.append(bn_node)

        def disable_fusion(self):
            self.fusion_enabled = False

        def is_fusion_enabled(self):
            return self.fusion_enabled

    conv_bn_to_fuse: Dict[int, ConvBNFusion] = {}
    for pattern in modules_patterns:
        conv_bn_to_fuse.clear()
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue

                # Hash based on the module name of the conv
                hash_id = hash(node.args[0].target)
                if hash_id not in conv_bn_to_fuse:
                    conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn)
                else:
                    if bn == conv_bn_to_fuse[hash_id].bn_module:
                        # Do fusion if it is the same bn module
                        conv_bn_to_fuse[hash_id].add_bn_node(node)
                    else:
                        # Disable conv-bn folding if the conv is shared by different bns
                        conv_bn_to_fuse[hash_id].disable_fusion()

        for conv_bn_fusion in conv_bn_to_fuse.values():
            if conv_bn_fusion.is_fusion_enabled():
                bn_nodes = conv_bn_fusion.bn_nodes
                conv = conv_bn_fusion.conv_module
                bn = conv_bn_fusion.bn_module

                fused_conv = fuse_conv_bn_eval(conv, bn)
                for bn_node in bn_nodes:
                    replace_node_module(bn_node.args[0], modules, fused_conv)
                    bn_node.replace_all_uses_with(bn_node.args[0])
                    gm.graph.erase_node(bn_node)

    gm.graph.lint()
    for pattern in module_function_patterns:
        conv_bn_to_fuse.clear()
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue

                def _used_by_same_conv_module(users):
                    conv_module_name = users[0].args[0].target
                    return all(
                        conv_module_name == user.args[0].target for user in users
                    )

                bn_args_is_constant = all(
                    n.op == "get_attr"
                    and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users)))
                    for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue

                # Hash based on the module name of the conv
                hash_id = hash(node.args[0].target)
                if hash_id not in conv_bn_to_fuse:
                    conv_bn_to_fuse[hash_id] = ConvBNFusion(
                        node,
                        conv,
                        bn_running_mean=bn_running_mean,
                        bn_running_var=bn_running_var,
                        bn_eps=bn_eps,
                        bn_weight=bn_weight,
                        bn_bias=bn_bias,
                    )
                else:
                    if (
                        hash(bn_running_mean)
                        == hash(conv_bn_to_fuse[hash_id].bn_running_mean)
                        and hash(bn_running_var)
                        == hash(conv_bn_to_fuse[hash_id].bn_running_var)
                        and torch.allclose(
                            torch.tensor(bn_eps),
                            torch.tensor(conv_bn_to_fuse[hash_id].bn_eps),
                        )
                        and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight)
                        and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias)
                    ):
                        # Do fusion if it is the same functional bn
                        conv_bn_to_fuse[hash_id].add_bn_node(node)
                    else:
                        # Disable conv-bn folding if the conv is shared by different bns
                        conv_bn_to_fuse[hash_id].disable_fusion()

        for conv_bn_fusion in conv_bn_to_fuse.values():
            if conv_bn_fusion.is_fusion_enabled():
                bn_nodes = conv_bn_fusion.bn_nodes
                conv = conv_bn_fusion.conv_module
                bn_running_mean = conv_bn_fusion.bn_running_mean
                bn_running_var = conv_bn_fusion.bn_running_var
                bn_eps = conv_bn_fusion.bn_eps
                bn_weight = conv_bn_fusion.bn_weight
                bn_bias = conv_bn_fusion.bn_bias

                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                for bn_node in bn_nodes:
                    replace_node_module(bn_node.args[0], modules, fused_conv)
                    bn_node.replace_all_uses_with(bn_node.args[0])
                    gm.graph.erase_node(bn_node)
    gm.graph.lint()
    gm.recompile()

    return gm
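

# F.linear / torch.bmm / torch.matmul may appear in the graph with positional
# or keyword arguments. The wrappers below give the fusion passes a single
# accessor per operand regardless of the call convention, e.g. both
# F.linear(x, w, b) and F.linear(input=x, weight=w, bias=b) yield the same
# (input, weight, bias) triple.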
class NormalizedLinearNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]  # type: ignore[return-value]
        else:
            return self.node.kwargs["input"]  # type: ignore[return-value]

    def get_weight(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]  # type: ignore[return-value]
        else:
            return self.node.kwargs["weight"]  # type: ignore[return-value]

    def get_bias(self) -> torch.fx.Node:
        if len(self.node.args) > 2:
            return self.node.args[2]  # type: ignore[return-value]
        else:
            return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None  # type: ignore[return-value]


class NormalizedMatmulNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]  # type: ignore[return-value]
        else:
            return self.node.kwargs["input"]  # type: ignore[return-value]

    def get_other(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]  # type: ignore[return-value]
        else:
            return self.node.kwargs["other"]  # type: ignore[return-value]


def check_permute(node: torch.fx.Node) -> bool:
    ranks = len(node.meta["tensor_meta"].shape)
    if len(node.args) > 3:
        permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]  # type: ignore[operator]
    elif (
        "permutation" in node.kwargs
        and node.kwargs["permutation"] is not None
        and len(node.kwargs["permutation"]) > 2  # type: ignore[arg-type]
    ):
        permutation = [i % ranks for i in node.kwargs["permutation"]]  # type: ignore[union-attr]
    else:
        return False
    allowed_permutation = list(range(ranks))
    allowed_permutation[-1] = ranks - 2
    allowed_permutation[-2] = ranks - 1
    return permutation == allowed_permutation
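

# check_permute accepts only permutations that swap the last two dims and keep
# all leading dims in place, and only for rank >= 3; e.g. for a rank-3 tensor:
#   x.permute(0, 2, 1)  -> True  (only the last two dims swap)
#   x.permute(2, 1, 0)  -> False (a leading dim moves)
# It reads node.meta["tensor_meta"], so ShapeProp must have been run first.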


def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Push a unary pointwise op (relu/tanh) that follows a cat, possibly through
    a chain of views, onto the cat's inputs, e.g.
    torch.relu(torch.cat([a, b])) -> torch.cat([torch.relu(a), torch.relu(b)]).
    """

    def one_user(node):
        users = list(node.users)
        return users[0] if len(users) == 1 else None

    def is_view(node):
        view = {"view"}
        return node.op == "call_method" and node.target in view

    def is_pointwise_unary(node):
        pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
        return node.op in {"call_function", "call_method"} and node.target in pointwise

    g = module.graph
    for node in g.nodes:
        if node.op != "call_function" or node.target != torch.cat:
            continue

        cat_or_view = node
        while True:
            user = one_user(cat_or_view)
            if not user or not is_view(user):
                break
            cat_or_view = user

        if user and is_pointwise_unary(user):
            with g.inserting_before(node):

                def cat_args(tensors, dim=0):
                    return tensors, dim

                tensors, dim = cat_args(*node.args, **node.kwargs)
                new_kwargs = {
                    name: val for name, val in user.kwargs.items() if name != "input"
                }
                new_tensors = [
                    g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs)
                    for arg in tensors
                ]
                new_cat = g.create_node(
                    "call_function", torch.cat, args=(new_tensors, dim)
                )
                user.replace_all_uses_with(cat_or_view)
                node.replace_all_uses_with(new_cat)
                g.erase_node(user)
                g.erase_node(node)
    g.lint()
    module.recompile()
    return module


def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.find_nodes(op="call_method", target="permute"):
        if check_permute(node):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module


# Y1 = X * W^T + bias
# Y2 = Y1.permute(0, 2, 1)
# ---->
# Y2 = W * X^T + bias.unsqueeze(-1)
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    if bias is None:
        return torch.matmul(weight, input.transpose(-1, -2))
    return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1)


def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.find_nodes(
        op="call_function", target=torch.nn.functional.linear
    ):
        if len(node.args) > 0:
            input_node = node.args[0]
        else:
            input_node = node.kwargs["input"]
        if (
            input_node.op == "call_method"
            and input_node.target == "permute"
            and check_permute(input_node)
        ):
            normalized = NormalizedLinearNode(node)
            if len(input_node.args) > 0:
                input = input_node.args[0]
            else:
                input = input_node.kwargs["input"]
            weight = normalized.get_weight()
            bias = normalized.get_bias()
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_linear, args=(input, weight, bias)
                )
                node.replace_all_uses_with(fused_node)
                module.graph.erase_node(node)
                if len(input_node.users) == 0:
                    module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module


def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in itertools.chain(
        module.graph.find_nodes(op="call_function", target=torch.bmm),
        module.graph.find_nodes(op="call_function", target=torch.matmul),
    ):
        normalized = NormalizedMatmulNode(node)
        input_A_node = normalized.get_input()
        input_B_node = normalized.get_other()
        input_A = input_A_node
        input_B = input_B_node
        Atrans = Btrans = False
        if (
            input_A_node.op == "call_method"
            and input_A_node.target == "permute"
            and check_permute(input_A_node)
        ):
            Atrans = True
            if len(input_A_node.args) > 0:
                input_A = input_A_node.args[0]  # type: ignore[assignment]
            else:
                input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]

        if (
            input_B_node.op == "call_method"
            and input_B_node.target == "permute"
            and check_permute(input_B_node)
        ):
            Btrans = True
            if len(input_B_node.args) > 0:
                input_B = input_B_node.args[0]  # type: ignore[assignment]
            else:
                input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]

        if Atrans or Btrans:
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_matmul,
                    args=(input_A, input_B, Atrans, Btrans),
                )
            node.replace_all_uses_with(fused_node)
            module.graph.erase_node(node)
            if Atrans and len(input_A_node.users) == 0:
                module.graph.erase_node(input_A_node)
            if Btrans and len(input_B_node.users) == 0:
                module.graph.erase_node(input_B_node)

    module.graph.lint()
    module.recompile()
    return module


# X1 = X.permute(0, 2, 1)
# Y1 = X1 * W1^T + bias1
# ---->
# Y1 = X.transpose(-1, -2) * W1^T + bias1
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    if bias is None:
        return torch.matmul(input.transpose(-1, -2), weight.t())
    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias
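

# A1 = A.permute(0, 2, 1)  (and/or B1 = B.permute(0, 2, 1))
# Y = torch.matmul(A1, B1)
# ---->
# Y = torch.matmul(A.transpose(-1, -2), B.transpose(-1, -2))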
def transpose_matmul(
    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
) -> torch.Tensor:
    if Atrans:
        A = A.transpose(-1, -2)
    if Btrans:
        B = B.transpose(-1, -2)
    return torch.matmul(A, B)
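

# Quick equivalence check for the rewrites above (illustrative only, not part
# of the pass; shapes are arbitrary):
#   a, b = torch.randn(2, 3, 4), torch.randn(2, 3, 5)
#   assert torch.allclose(
#       torch.matmul(a.permute(0, 2, 1), b),
#       transpose_matmul(a, b, Atrans=True, Btrans=False),
#   )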