# mypy: allow-untyped-defs
import copy
import itertools
import logging
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
from torch._utils_internal import upload_graph
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights

from .. import config
from ..fx_utils import matches_module_function_pattern
from ..pattern_matcher import (
    init_once_fakemode,
    PatternMatcherPass,
    stable_topological_sort,
)
from ..utils import is_cpu_device, pass_execution_and_save
from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
from .misc_patterns import numpy_compat_normalization
from .split_cat import PRE_GRAD_PATTERNS


log = logging.getLogger(__name__)

efficient_conv_bn_eval_pass = PatternMatcherPass(
    pass_name="efficient_conv_bn_eval_pass"
)

fuse_split_linear_add_pass = PatternMatcherPass(
    pass_name="fuse_split_linear_add_pass",
)
fuse_chunk_squeeze_cat_pass = PatternMatcherPass(
    pass_name="fuse_chunk_squeeze_cat_pass",
)
remove_reshape_pass = PatternMatcherPass(
    pass_name="remove_reshape_pass",
)

# based on predispatch aten IR
normalization_pass_aten = PatternMatcherPass()
merge_splits_pass_aten = PatternMatcherPass()
split_cat_pass_aten = PatternMatcherPass()
unbind_stack_pass_aten = PatternMatcherPass()
merge_getitem_cat_pass_aten = PatternMatcherPass()
merge_stack_tahn_unbind_pass_aten = PatternMatcherPass()
mutate_cat_pass_aten = PatternMatcherPass()
remove_split_with_size_one_pass_aten = PatternMatcherPass()


def save_inductor_dict(pass_to_compare=None):
    if not pass_to_compare:
        pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list(
            config.post_grad_fusion_options.keys()
        )
    return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare}


def is_same_dict(inductor_dict, optimus_dict):
    for pass_name, count in optimus_dict.items():
        if count != dict(inductor_dict).get(pass_name, 0):
            return False
    return True
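
# A minimal usage sketch of the two helpers above (the pass name is
# hypothetical; real keys come from config.pre_grad_fusion_options):
#
#   before = save_inductor_dict(["my_fusion_pass"])
#   my_pattern_matcher_pass.apply(gm.graph)
#   if not is_same_dict(counters["inductor"], before):
#       ...  # the pass matched and changed the graph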


# The following passes are no-op placeholders in this build; internal builds
# may override them (see the `fb` import in lazy_init below).
def normalize_node_kwargs_pass(graph):
    return None


def fuse_parallel_linear_pass(graph):
    return None


def remove_split_ops(graph, shape_prop):
    return None


def fuse_chunk_reshape_unsqueeze_concat_pass(graph):
    return None


def fuse_chunk_reshape_concat_pass(graph):
    return None


def remove_noop_pass(graph):
    return None


def stack_to_unsqueeze_pass(graph):
    return None


@init_once_fakemode
def lazy_init():
    from . import efficient_conv_bn_eval, split_cat  # noqa: F401

    if config.is_fbcode():
        from . import fb  # type: ignore[attr-defined]  # noqa: F401


def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
    """
    Apply passes on the input FX graph using Torch IR.

    WARNING:
    The IR before grad is not functional or normalized, so it is harder
    to write passes on this IR.  Passes must be safe with respect to
    aliasing and mutation and need to handle all possible arg schemas.

    Consider adding a new pass to post_grad.py or joint_graph.py, which
    run after functionalization and normalization.
    """
    if config.pattern_matcher:
        lazy_init()
        if hasattr(
            config, "fx_passes_numeric_check"
        ) and config.fx_passes_numeric_check.get("pre_grad", False):
            gm_before_fx_passes = gm.__copy__()
        # explicitly run passes based on the predispatch ATen IR
        if config.is_predispatch:

            def shape_prop(mod) -> None:
                ShapeProp(
                    gm=mod,
                    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
                    fake_mode=detect_fake_mode(example_inputs),
                ).propagate(*example_inputs)

            # normalization pass
            pass_execution_and_save(
                normalization_pass_aten.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply normalization pass",
            )
            # normalize kwargs; must be called as the first pass
            pass_execution_and_save(
                normalize_node_kwargs_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply normalize_node_kwargs_pass",
            )
            pass_execution_and_save(
                remove_noop_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_noop pass",
            )
            pass_execution_and_save(
                fuse_chunk_reshape_concat_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_concat_pass",
            )
            pass_execution_and_save(
                group_batch_fusion_passes,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply group_batch_fusion",
            )
            pass_execution_and_save(
                normalize_node_kwargs_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply normalize_node_kwargs_pass",
            )
            pass_execution_and_save(
                fuse_chunk_squeeze_cat_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
            )
            pass_execution_and_save(
                fuse_split_linear_add_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
            )
            pass_execution_and_save(
                remove_reshape_pass.apply,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
            )
            pass_execution_and_save(
                fuse_parallel_linear_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
            )
            pass_execution_and_save(
                lambda graph: remove_split_ops(graph.owning_module, shape_prop),
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_split_ops",
            )
            # run before fuse_chunk_reshape_unsqueeze_concat_pass
            pass_execution_and_save(
                stack_to_unsqueeze_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply stack_to_unsqueeze_pass",
            )
            pass_execution_and_save(
                fuse_chunk_reshape_unsqueeze_concat_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply fuse_chunk_reshape_unsqueeze_concat_pass",
            )
            # Remove noops at the end; they may be generated by other passes.
            pass_execution_and_save(
                remove_noop_pass,
                gm,
                example_inputs,
                "[Pre grad(predispatch IR)] Apply remove_noop pass",
            )
            shape_prop(gm)

        else:
            # We only log graphs that changed, to avoid excessive compilation time
            # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/
            if example_inputs is not None:
                gm = fuse_fx(gm, example_inputs)
            numpy_compat_normalization(gm.graph)
            optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
            group_batch_fusion_passes(gm.graph, pre_grad=True)
            for pass_name in config.pre_grad_fusion_options:
                # skip all patterns for group batch fusions
                if pass_name in PRE_GRAD_FUSIONS:
                    continue
                pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
                inductor_before_change = save_inductor_dict(
                    [pattern_matcher_pass.pass_name]
                )
                # we support running the same pattern multiple times; the default is to run it only once
                counter = config.pre_grad_fusion_options[pass_name].get("counter", 1)
                for _ in range(counter):
                    pattern_matcher_pass.apply(gm.graph)  # type: ignore[arg-type]
                if not is_same_dict(counters["inductor"], inductor_before_change):
                    optimus_scuba_log[
                        f"{pattern_matcher_pass.pass_name}_pre_grad"
                    ] = upload_graph(gm.graph)
            # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
            efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]

    if config.pre_grad_custom_pass is not None:
        with GraphTransformObserver(
            gm, "pre_grad_custom_pass", config.trace.log_url_for_graph_xform
        ):
            config.pre_grad_custom_pass(gm.graph)
    stable_topological_sort(gm.graph)

    from .quantization import quant_lift_up

    quant_lift_up(gm)

    gm.graph.lint()
    gm.recompile()
    optimus_scuba_log["after_recompile_pre_grad"] = upload_graph(gm.graph)

    if (
        config.pattern_matcher
        and hasattr(config, "fx_passes_numeric_check")
        and config.fx_passes_numeric_check.get("pre_grad", False)
        and example_inputs is not None
    ):
        from .numeric_utils import numeric_check_if_enabled

        gm_after_fx_passes = gm.__copy__()
        numeric_check_if_enabled(
            gm_before_fx_passes,  # type: ignore[possibly-undefined]
            gm_after_fx_passes,
            example_inputs,
            config.fx_passes_numeric_check.get("num_iterations", 1),
            config.fx_passes_numeric_check.get("precision", 1e-4),
        )

    return gm


def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    """
    Sink concatenations below pointwise ops; on non-CPU inputs, optionally
    fuse permutes into linear/matmul; on CPU with autograd disabled and
    freezing enabled, remove identity layers and fold batch norms into
    convolutions.
    """
    is_cpu = is_cpu_device(example_inputs)
    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
    fake_mode = detect_fake_mode(example_inputs)

    gm = sink_cat_after_pointwise(gm)
    if config.permute_fusion and not is_cpu:
        # For linear permute fusion, we need to check input info to identify
        # and perform proper permutation/transpose
        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
        with GraphTransformObserver(
            gm, "linear_permute_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = linear_permute_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_linear_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_linear_fusion(gm)
        with GraphTransformObserver(
            gm, "permute_matmul_fusion", config.trace.log_url_for_graph_xform
        ):
            gm = permute_matmul_fusion(gm)

    # make sure autograd is disabled
    if torch.is_grad_enabled() or not is_cpu:
        return gm
    if config.freezing:
        with GraphTransformObserver(
            gm, "remove_identity", config.trace.log_url_for_graph_xform
        ):
            gm = remove_identity(gm)
        with GraphTransformObserver(
            gm, "fuse_conv_bn", config.trace.log_url_for_graph_xform
        ):
            gm = fuse_conv_bn(gm)
    return gm


def fetch_attr(target: str, mod):
    """Resolve a dotted attribute path on `mod`, e.g. fetch_attr("block.conv", gm)."""
    target_atoms = target.split(".")
    attr_itr = mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr


def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Removes all identity layers from the module.
    """

    class IdentityRemover(torch.fx.Transformer):
        def call_module(self, target, args, kwargs):
            if isinstance(self.submodules[target], nn.Identity):
                assert len(args) == 1
                return args[0]
            else:
                return super().call_module(target, args, kwargs)

    return IdentityRemover(gm).transform()


def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule:
    """
    Fuses Convolution/BN layers for inference purposes.
    """
    modules_patterns = [
        (torch.nn.Conv1d, torch.nn.BatchNorm1d),
        (torch.nn.Conv2d, torch.nn.BatchNorm2d),
        (torch.nn.Conv3d, torch.nn.BatchNorm3d),
    ]
    module_function_patterns = [
        (torch.nn.Conv1d, F.batch_norm),
        (torch.nn.Conv2d, F.batch_norm),
        (torch.nn.Conv3d, F.batch_norm),
    ]
    modules = dict(gm.named_modules())

    class ConvBNFusion:
        def __init__(
            self,
            bn_node,
            conv_module,
            bn_module=None,  # for BN module patterns
            bn_running_mean=None,  # for functional BN patterns
            bn_running_var=None,
            bn_eps=None,
            bn_weight=None,
            bn_bias=None,
        ) -> None:
            self.bn_nodes = [
                bn_node,
            ]
            self.conv_module = conv_module
            self.bn_module = bn_module
            self.bn_running_mean = bn_running_mean
            self.bn_running_var = bn_running_var
            self.bn_eps = bn_eps
            self.bn_weight = bn_weight
            self.bn_bias = bn_bias
            self.fusion_enabled = True

        def add_bn_node(self, bn_node):
            self.bn_nodes.append(bn_node)

        def disable_fusion(self):
            self.fusion_enabled = False

        def is_fusion_enabled(self):
            return self.fusion_enabled

    conv_bn_to_fuse: Dict[int, ConvBNFusion] = {}
    for pattern in modules_patterns:
        conv_bn_to_fuse.clear()
        for node in gm.graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                eval_mode = all(not n.training for n in [conv, bn])
                if not eval_mode:
                    continue
                if not bn.track_running_stats:
                    continue

                # Hash based on the conv module name
                hash_id = hash(node.args[0].target)
                if hash_id not in conv_bn_to_fuse:
                    conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn)
                else:
                    if bn == conv_bn_to_fuse[hash_id].bn_module:
                        # Fuse if it is the same bn module
                        conv_bn_to_fuse[hash_id].add_bn_node(node)
                    else:
                        # Disable conv-bn folding if the conv is shared by different bn modules
                        conv_bn_to_fuse[hash_id].disable_fusion()

        for conv_bn_fusion in conv_bn_to_fuse.values():
            if conv_bn_fusion.is_fusion_enabled():
                bn_nodes = conv_bn_fusion.bn_nodes
                conv = conv_bn_fusion.conv_module
                bn = conv_bn_fusion.bn_module

                fused_conv = fuse_conv_bn_eval(conv, bn)
                for bn_node in bn_nodes:
                    replace_node_module(bn_node.args[0], modules, fused_conv)
                    bn_node.replace_all_uses_with(bn_node.args[0])
                    gm.graph.erase_node(bn_node)

    gm.graph.lint()
    for pattern in module_function_patterns:
        conv_bn_to_fuse.clear()
        for node in gm.graph.nodes:
            if matches_module_function_pattern(pattern, node, modules):
                # TODO: support kwargs.
                if len(node.args) != 8:
                    continue
                conv = modules[node.args[0].target]
                bn_training = node.args[5]
                bn_eps = node.args[7]
                if conv.training or bn_training:
                    continue
                if type(bn_eps) is not float:
                    continue

                def _used_by_same_conv_module(users):
                    conv_module_name = users[0].args[0].target
                    return all(
                        conv_module_name == user.args[0].target for user in users
                    )

                bn_args_is_constant = all(
                    n.op == "get_attr"
                    and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users)))
                    for n in node.args[1:5]
                )
                if not bn_args_is_constant:
                    continue
                bn_running_mean = fetch_attr(node.args[1].target, gm)
                bn_running_var = fetch_attr(node.args[2].target, gm)
                bn_weight = fetch_attr(node.args[3].target, gm)
                bn_bias = fetch_attr(node.args[4].target, gm)
                if bn_running_mean is None or bn_running_var is None:
                    continue

                # Hash based on the conv module name
                hash_id = hash(node.args[0].target)
                if hash_id not in conv_bn_to_fuse:
                    conv_bn_to_fuse[hash_id] = ConvBNFusion(
                        node,
                        conv,
                        bn_running_mean=bn_running_mean,
                        bn_running_var=bn_running_var,
                        bn_eps=bn_eps,
                        bn_weight=bn_weight,
                        bn_bias=bn_bias,
                    )
                else:
                    if (
                        hash(bn_running_mean)
                        == hash(conv_bn_to_fuse[hash_id].bn_running_mean)
                        and hash(bn_running_var)
                        == hash(conv_bn_to_fuse[hash_id].bn_running_var)
                        and torch.allclose(
                            torch.tensor(bn_eps),
                            torch.tensor(conv_bn_to_fuse[hash_id].bn_eps),
                        )
                        and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight)
                        and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias)
                    ):
                        # Fuse if it is the same functional bn
                        conv_bn_to_fuse[hash_id].add_bn_node(node)
                    else:
                        # Disable conv-bn folding if the conv is shared by different bn modules
                        conv_bn_to_fuse[hash_id].disable_fusion()

        for conv_bn_fusion in conv_bn_to_fuse.values():
            if conv_bn_fusion.is_fusion_enabled():
                bn_nodes = conv_bn_fusion.bn_nodes
                conv = conv_bn_fusion.conv_module
                bn_running_mean = conv_bn_fusion.bn_running_mean
                bn_running_var = conv_bn_fusion.bn_running_var
                bn_eps = conv_bn_fusion.bn_eps
                bn_weight = conv_bn_fusion.bn_weight
                bn_bias = conv_bn_fusion.bn_bias

                fused_conv = copy.deepcopy(conv)
                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
                    fused_conv.weight,
                    fused_conv.bias,
                    bn_running_mean,
                    bn_running_var,
                    bn_eps,
                    bn_weight,
                    bn_bias,
                )
                for bn_node in bn_nodes:
                    replace_node_module(bn_node.args[0], modules, fused_conv)
                    bn_node.replace_all_uses_with(bn_node.args[0])
                    gm.graph.erase_node(bn_node)
    gm.graph.lint()
    gm.recompile()

    return gm


class NormalizedLinearNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.nn.functional.linear]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]  # type: ignore[return-value]
        else:
            return self.node.kwargs["input"]  # type: ignore[return-value]

    def get_weight(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]  # type: ignore[return-value]
        else:
            return self.node.kwargs["weight"]  # type: ignore[return-value]

    def get_bias(self) -> torch.fx.Node:
        if len(self.node.args) > 2:
            return self.node.args[2]  # type: ignore[return-value]
        else:
            return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None  # type: ignore[return-value]


class NormalizedMatmulNode:
    def __init__(self, node: torch.fx.Node) -> None:
        assert node.op == "call_function"
        assert node.target in [torch.bmm, torch.matmul]
        self.node: torch.fx.Node = node

    def get_input(self) -> torch.fx.Node:
        if len(self.node.args) > 0:
            return self.node.args[0]  # type: ignore[return-value]
        else:
            return self.node.kwargs["input"]  # type: ignore[return-value]

    def get_other(self) -> torch.fx.Node:
        if len(self.node.args) > 1:
            return self.node.args[1]  # type: ignore[return-value]
        else:
            return self.node.kwargs["other"]  # type: ignore[return-value]


def check_permute(node: torch.fx.Node) -> bool:
    ranks = len(node.meta["tensor_meta"].shape)
    if len(node.args) > 3:
        permutation = [node.args[i] % ranks for i in range(1, ranks + 1)]  # type: ignore[operator]
    elif (
        "permutation" in node.kwargs
        and node.kwargs["permutation"] is not None
        and len(node.kwargs["permutation"]) > 2  # type: ignore[arg-type]
    ):
        permutation = [i % ranks for i in node.kwargs["permutation"]]  # type: ignore[union-attr]
    else:
        return False
    allowed_permutation = list(range(ranks))
    allowed_permutation[-1] = ranks - 2
    allowed_permutation[-2] = ranks - 1
    return permutation == allowed_permutation
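
# Example: x.permute(0, 2, 1) on a rank-3 tensor passes check_permute (only the
# last two dims are swapped), while x.permute(2, 0, 1) does not.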


def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
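    """
    Sink a cat (and any intermediate views) below a pointwise unary op so the
    op is applied to each cat input instead. An illustrative sketch:

        torch.cat([a, b]).view(shape).relu()
        # becomes
        torch.cat([a.relu(), b.relu()]).view(shape)
    """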
    def one_user(node):
        users = list(node.users)
        return users[0] if len(users) == 1 else None

    def is_view(node):
        view = {"view"}
        return node.op == "call_method" and node.target in view

    def is_pointwise_unary(node):
        pointwise = {torch.relu, torch.tanh, "relu", "tanh"}
        return node.op in {"call_function", "call_method"} and node.target in pointwise

    g = module.graph
    for node in g.nodes:
        if node.op != "call_function" or node.target != torch.cat:
            continue

        cat_or_view = node
        while True:
            user = one_user(cat_or_view)
            if not user or not is_view(user):
                break
            cat_or_view = user

        if user and is_pointwise_unary(user):
            with g.inserting_before(node):

                def cat_args(tensors, dim=0):
                    return tensors, dim

                tensors, dim = cat_args(*node.args, **node.kwargs)
                new_kwargs = {
                    name: val for name, val in user.kwargs.items() if name != "input"
                }
                new_tensors = [
                    g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs)
                    for arg in tensors
                ]
                new_cat = g.create_node(
                    "call_function", torch.cat, args=(new_tensors, dim)
                )
                user.replace_all_uses_with(cat_or_view)
                node.replace_all_uses_with(new_cat)
                g.erase_node(user)
                g.erase_node(node)
    g.lint()
    module.recompile()
    return module


def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.find_nodes(op="call_method", target="permute"):
        if check_permute(node):
            if len(node.args) > 0:
                input_node = node.args[0]
            else:
                input_node = node.kwargs["input"]
            if (
                input_node.op == "call_function"
                and input_node.target == torch.nn.functional.linear
            ):
                normalized = NormalizedLinearNode(input_node)
                input = normalized.get_input()
                weight = normalized.get_weight()
                bias = normalized.get_bias()
                with module.graph.inserting_before(node):
                    fused_node = module.graph.call_function(
                        linear_transpose, args=(input, weight, bias)
                    )
                    node.replace_all_uses_with(fused_node)
                    module.graph.erase_node(node)
                    if len(input_node.users) == 0:
                        module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module


# Y1 = X * W^T + bias
# Y2 = Y1.permute(0, 2, 1)
# ---->
# Y2 = W * X^T + bias.unsqueeze(-1)
def linear_transpose(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    if bias is None:
        return torch.matmul(weight, input.transpose(-1, -2))
    return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1)
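
# A quick numeric sanity check of the identity above (illustrative only; the
# shapes are hypothetical):
#
#   x, w, b = torch.randn(2, 3, 4), torch.randn(5, 4), torch.randn(5)
#   y1 = torch.nn.functional.linear(x, w, b).permute(0, 2, 1)
#   torch.testing.assert_close(y1, linear_transpose(x, w, b))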


def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in module.graph.find_nodes(
        op="call_function", target=torch.nn.functional.linear
    ):
        if len(node.args) > 0:
            input_node = node.args[0]
        else:
            input_node = node.kwargs["input"]
        if (
            input_node.op == "call_method"
            and input_node.target == "permute"
            and check_permute(input_node)
        ):
            normalized = NormalizedLinearNode(node)
            if len(input_node.args) > 0:
                input = input_node.args[0]
            else:
                input = input_node.kwargs["input"]
            weight = normalized.get_weight()
            bias = normalized.get_bias()
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_linear, args=(input, weight, bias)
                )
                node.replace_all_uses_with(fused_node)
                module.graph.erase_node(node)
                if len(input_node.users) == 0:
                    module.graph.erase_node(input_node)

    module.graph.lint()
    module.recompile()
    return module


def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in itertools.chain(
        module.graph.find_nodes(op="call_function", target=torch.bmm),
        module.graph.find_nodes(op="call_function", target=torch.matmul),
    ):
        normalized = NormalizedMatmulNode(node)
        input_A_node = normalized.get_input()
        input_B_node = normalized.get_other()
        input_A = input_A_node
        input_B = input_B_node
        Atrans = Btrans = False
        if (
            input_A_node.op == "call_method"
            and input_A_node.target == "permute"
            and check_permute(input_A_node)
        ):
            Atrans = True
            if len(input_A_node.args) > 0:
                input_A = input_A_node.args[0]  # type: ignore[assignment]
            else:
                input_A = input_A_node.kwargs["input"]  # type: ignore[assignment]

        if (
            input_B_node.op == "call_method"
            and input_B_node.target == "permute"
            and check_permute(input_B_node)
        ):
            Btrans = True
            if len(input_B_node.args) > 0:
                input_B = input_B_node.args[0]  # type: ignore[assignment]
            else:
                input_B = input_B_node.kwargs["input"]  # type: ignore[assignment]

        if Atrans or Btrans:
            with module.graph.inserting_before(node):
                fused_node = module.graph.call_function(
                    transpose_matmul,
                    args=(input_A, input_B, Atrans, Btrans),
                )
            node.replace_all_uses_with(fused_node)
            module.graph.erase_node(node)
            if Atrans and len(input_A_node.users) == 0:
                module.graph.erase_node(input_A_node)
            if Btrans and len(input_B_node.users) == 0:
                module.graph.erase_node(input_B_node)

    module.graph.lint()
    module.recompile()
    return module


# X1 = X.permute(0, 2, 1)
# Y1 = X1 * W1^T + bias1
# ---->
# Y1 = X.transpose(-1, -2) * W1^T + bias1
def transpose_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    if bias is None:
        return torch.matmul(input.transpose(-1, -2), weight.t())
    return torch.matmul(input.transpose(-1, -2), weight.t()) + bias
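
# A matching sanity check for transpose_linear (illustrative only; the shapes
# are hypothetical):
#
#   x, w, b = torch.randn(2, 3, 4), torch.randn(5, 3), torch.randn(5)
#   y1 = torch.nn.functional.linear(x.permute(0, 2, 1), w, b)
#   torch.testing.assert_close(y1, transpose_linear(x, w, b))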


def transpose_matmul(
    A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool
) -> torch.Tensor:
    if Atrans:
        A = A.transpose(-1, -2)
    if Btrans:
        B = B.transpose(-1, -2)
    return torch.matmul(A, B)
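
# Usage sketch (illustrative): transpose_matmul(a, b, True, False) computes
# torch.matmul(a.transpose(-1, -2), b), i.e. the value of
# a.permute(0, 2, 1) @ b for rank-3 inputs.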