xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/mkldnn_fusion.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import operator
4from functools import reduce
5from typing import Any, Tuple
6
7import torch
8from torch.fx.experimental.symbolic_shapes import has_free_symbols
9
10from .. import ir
11from ..lowering import lowerings as L
12from ..pattern_matcher import (
13    Arg,
14    CallFunction,
15    filter_nodes,
16    get_arg_value,
17    KeywordArg,
18    MULTIPLE,
19)
20from ..virtualized import ops, V
21from .freezing_patterns import register_freezing_graph_pattern
22from .post_grad import register_lowering_pattern
23from .quantization import (
24    _register_quantization_lowerings,
25    _register_quantization_weight_pack_pass,
26    _register_woq_lowerings,
27)
28
29
30if torch._C._has_mkldnn:
31    aten = torch.ops.aten
32    mkldnn = torch.ops.mkldnn
33    prims = torch.ops.prims
34
35    _conv_args = [Arg() for _ in range(10)]
36    _linear_args = [Arg() for _ in range(6)]
37    _conv_transpose_args = [Arg() for _ in range(11)]
38
39    def _conv_call(users=1):
40        return CallFunction(
41            mkldnn._convolution_pointwise.default, *_conv_args, _users=users
42        )
43
44    def _linear_call(users=1):
45        return CallFunction(
46            mkldnn._linear_pointwise.default, *_linear_args, _users=users
47        )
48
49    def _conv_transpose_call(users=1):
50        return CallFunction(
51            mkldnn._convolution_transpose_pointwise.default,
52            *_conv_transpose_args,
53            _users=users,
54        )
55
56    def _to_float(input_call, users=1):
57        return CallFunction(
58            prims.convert_element_type.default,
59            input_call,
60            KeywordArg("to_float"),
61            _users=users,
62        )
63
64    def _to_bf16(input_call):
65        return CallFunction(
66            prims.convert_element_type.default,
67            input_call,
68            KeywordArg("to_bf16"),
69            _users=1,
70        )
71
72    def _to_fp16(input_call):
73        return CallFunction(
74            prims.convert_element_type.default,
75            input_call,
76            KeywordArg("to_fp16"),
77            _users=1,
78        )
79
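    # In low-precision (bf16/fp16) mode the matched graph has the shape
    # computation_op -> convert_to_fp32 -> decomposed unary -> convert_to_lowp_dtype;
    # in fp32 mode the two dtype conversions are absent.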
80    def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype):
81        # only insert the to_dtype conversions if lowp_dtype is set
82        computation_call = (
83            _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users)
84        )
85        out = unary_fusion(computation_call)
86        if lowp_dtype == torch.bfloat16:
87            return _to_bf16(out)
88        elif lowp_dtype == torch.float16:
89            return _to_fp16(out)
90        else:
91            return out
92
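    # Decomposed exact ("none" algorithm) GeLU:
    #   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), with 0.7071067811865476 ~= 1/sqrt(2).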
93    def _gelu_fusion_1(computation_call):
94        return CallFunction(
95            aten.mul,
96            CallFunction(aten.mul, computation_call, 0.5),
97            CallFunction(
98                aten.add,
99                CallFunction(
100                    aten.erf,
101                    CallFunction(aten.mul, computation_call, 0.7071067811865476),
102                ),
103                1,
104            ),
105        )
106
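    # Decomposed tanh-approximated GeLU:
    #   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    #   with 0.7978845608028654 ~= sqrt(2/pi).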
107    def _gelu_fusion_2(computation_call):
108        return CallFunction(
109            aten.mul,
110            CallFunction(aten.mul, computation_call, 0.5),
111            CallFunction(
112                aten.add,
113                CallFunction(
114                    aten.tanh,
115                    CallFunction(
116                        aten.mul,
117                        CallFunction(
118                            aten.add,
119                            computation_call,
120                            CallFunction(
121                                aten.mul,
122                                CallFunction(
123                                    aten.mul,
124                                    CallFunction(
125                                        aten.mul, computation_call, computation_call
126                                    ),
127                                    computation_call,
128                                ),
129                                0.044715,
130                            ),
131                        ),
132                        0.7978845608028654,
133                    ),
134                ),
135                1,
136            ),
137        )
138
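    # Decomposed hardswish: hardswish(x) = x * clamp(x + 3, 0, 6) / 6.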
139    def _hardswish_fusion(computation_call):
140        return CallFunction(
141            aten.div,
142            CallFunction(
143                aten.mul,
144                computation_call,
145                CallFunction(
146                    aten.clamp_max,
147                    CallFunction(
148                        aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
149                    ),
150                    6,
151                ),
152            ),
153            6,
154        )
155
156    def _silu_fusion(computation_call):
157        return CallFunction(
158            aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call)
159        )
160
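    # Decomposed hardsigmoid: hardsigmoid(x) = clamp(x + 3, 0, 6) / 6.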
161    def _hardsigmoid_fusion(computation_call):
162        return CallFunction(
163            aten.div,
164            CallFunction(
165                aten.clamp_max,
166                CallFunction(
167                    aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0
168                ),
169                6,
170            ),
171            6,
172        )
173
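    # Decomposed leaky_relu: where(x > 0, x, x * negative_slope); the slope is captured
    # as a KeywordArg so the lowering below can forward it to the fused op.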
174    def _leaky_relu_fusion(computation_call):
175        return CallFunction(
176            aten.where,
177            CallFunction(aten.gt, computation_call, 0),
178            computation_call,
179            CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")),
180        )
181
182    def _hardtanh_fusion(computation_call):
183        return CallFunction(
184            aten.clamp_max,
185            CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")),
186            KeywordArg("max_value"),
187        )
188
189    def _combined_fusion(computation_call, elementwise_op):
190        return CallFunction(elementwise_op, computation_call)
191
192    # binary_op(other, computation_op)
193    def _binary_fusion_v1(computation_call, binary_fn):
194        return CallFunction(binary_fn, KeywordArg("other"), computation_call)
195
196    # binary_op(computation_op, other)
197    def _binary_fusion_v2(computation_call, binary_fn):
198        return CallFunction(binary_fn, computation_call, KeywordArg("other"))
199
200    def _is_single_computation_op(computation_op, lowp_dtype=None):
201        def fn(match):
202            computation_nodes = filter_nodes(match.nodes, computation_op)
203
204            if lowp_dtype:
205                output_node_meta = match.output_node().meta.get("val")
206                if output_node_meta.dtype != lowp_dtype:
207                    return False
208
209            if len(computation_nodes) < 1:
210                return False
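            # args[-3] is the unary post-op "attr" argument of the mkldnn pointwise ops
            # matched above (their last three args are attr, scalars, algorithm); "none"
            # means no unary post-op has been fused into this node yet.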
211            if any(n.args[-3] != "none" for n in computation_nodes):
212                return False
213            return True
214
215        return fn
216
217    def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None):
218        def fn(match):
219            matched = _is_single_computation_op(computation_op, lowp_dtype)(match)
220            computation_node = filter_nodes(match.nodes, computation_op)[0]
221            if lowp_dtype:
222                conversion_dtype_nodes = filter_nodes(
223                    match.nodes, prims.convert_element_type.default
224                )
225                if len(conversion_dtype_nodes) != 2:
226                    return False
227                # the fusion pattern is always of the form computation_op + to_float32 + unary_op + to_lowp_dtype (bf16 or fp16)
228                if computation_node == conversion_dtype_nodes[0].args[0]:
229                    to_float = conversion_dtype_nodes[0].args[1]
230                    to_lp = conversion_dtype_nodes[1].args[1]
231                else:
232                    to_float = conversion_dtype_nodes[1].args[1]
233                    to_lp = conversion_dtype_nodes[0].args[1]
234                matched = matched and to_float == torch.float and to_lp == lowp_dtype
235            return matched
236
237        return fn
238
239    def _register_unary_fusion_lowering(
240        pattern, unary_attr, computation_op, lowp_dtype=None
241    ):
242        @register_lowering_pattern(
243            pattern,
244            extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype),
245        )
246        def fn(match, *args, **kwargs):
247            computation_args = list(args)[:-3] + [
248                unary_attr.op_name,
249                unary_attr.scalars_attr,
250                unary_attr.algorithm_attr,
251            ]
252            return L[computation_op](*computation_args)
253
254        return fn
255
256    def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None):
257        @register_lowering_pattern(
258            pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
259        )
260        def fn(match, *args, **kwargs):
261            negative_slope = kwargs.get("negative_slope")
262            if isinstance(negative_slope, ir.TensorBox):
263                matched = False
264            else:  # negative_slope is a Number
265                matched = True
266            if lowp_dtype:
267                dtype1 = kwargs.get("to_float")
268                dtype2 = (
269                    kwargs.get("to_bf16")
270                    if lowp_dtype == torch.bfloat16
271                    else kwargs.get("to_fp16")
272                )
273                matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
274            computation_args = list(args)
275            if matched:
276                computation_args = computation_args[:-3] + [
277                    "leaky_relu",
278                    [negative_slope],
279                    "",
280                ]
281                return L[computation_op](*computation_args)
282            else:
283                # computation_args += ["none", [], ""]
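                # Fallback: negative_slope is a Tensor (or the dtype check failed), so lower
                # the computation op without fusion and re-emit the decomposed leaky_relu
                # (plus the dtype conversions) that the matched pattern replaced.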
284                out = L[computation_op](*computation_args)
285                if lowp_dtype:
286                    out = L[prims.convert_element_type.default](out, dtype=torch.float)
287                out = L[aten.where](
288                    L[aten.gt](out, 0),
289                    out,
290                    L[aten.mul](out, negative_slope),
291                )
292                if lowp_dtype:
293                    out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
294                return out
295
296        return fn
297
298    def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None):
299        @register_lowering_pattern(
300            pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype)
301        )
302        def fn(match, *args, **kwargs):
303            min_value = kwargs.get("min_value")
304            max_value = kwargs.get("max_value")
305            if isinstance(min_value, ir.TensorBox) or isinstance(
306                max_value, ir.TensorBox
307            ):
308                matched = False
309            else:  # min_value and max_value are Numbers
310                assert max_value is not None
311                matched = min_value <= max_value
312            if lowp_dtype:
313                dtype1 = kwargs.get("to_float")
314                dtype2 = (
315                    kwargs.get("to_bf16")
316                    if lowp_dtype == torch.bfloat16
317                    else kwargs.get("to_fp16")
318                )
319                matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype
320            computation_args = list(args)
321            if matched:
322                computation_args = computation_args[:-3] + [
323                    "hardtanh",
324                    [min_value, max_value],
325                    "",
326                ]
327                return L[computation_op](*computation_args)
328            else:
329                out = L[computation_op](*computation_args)
330                if lowp_dtype:
331                    out = L[prims.convert_element_type.default](out, dtype=torch.float)
332                out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value)
333                if lowp_dtype:
334                    out = L[prims.convert_element_type.default](out, dtype=dtype2)  # type: ignore[possibly-undefined]
335                return out
336
337        return fn
338
339    _binary_attr = {
340        aten.add: "add",
341        ops.add: "add",
342        aten.sub: "sub",
343        ops.sub: "sub",
344    }
345
346    def _is_valid_binary(match, fn):
347        binary_nodes = filter_nodes(match.nodes, fn)
348        if len(binary_nodes) < 1:
349            return False
350
351        def get_meta_value(argument: torch.fx.node.Argument):
352            # Only torch.fx.Node is expected to have meta.
353            if isinstance(argument, torch.fx.Node):
354                return argument.meta.get("val", None)
355            return None
356
357        if any(
358            not isinstance(get_meta_value(n.args[0]), torch.Tensor)
359            or not isinstance(get_meta_value(n.args[1]), torch.Tensor)
360            for n in binary_nodes
361        ):
362            return False
363        # check alpha is one.
364        if any(
365            get_arg_value(n, 2, kwarg_name="alpha") != 1.0
366            and get_arg_value(n, 2, kwarg_name="alpha") is not None
367            for n in binary_nodes
368        ):
369            return False
370        if any(
371            get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
372            or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
373            or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
374            for n in binary_nodes
375        ):
376            return False
377        # check that args[0] and args[1] are not the same node
378        if any(n.args[0] == n.args[1] for n in binary_nodes):
379            return False
380        return True
381
382    def _is_valid_computation_binary(computation_op, binary_op, other_index=None):
383        def fn(match):
384            if not _is_single_computation_op(computation_op)(match):
385                return False
386            if not _is_valid_binary(match, binary_op):
387                return False
388            return True
389
390        return fn
391
392    def _get_remaining_users(extra_input_node, compute_node):
393        # Think about this pattern:
394        #      ReLU
395        #     /   \
396        #  Conv1
397        #   /      \
398        # Conv2
399        #   \      /
400        #      Add
401        # Although the extra input node (ReLU) has more than one user (Conv1 and Add),
402        # Conv1 is an ancestor of the current compute node (Conv2).
403        # This means all reads of the ReLU buffer have finished by the time Conv2 runs,
404        # so we can safely mutate it by doing the Conv2->Add inplace fusion.
405        # Taking the above case as an example:
406        # * extra_input_node: ReLU
407        # * compute_node: Conv2
408        # _get_remaining_users returns the users of extra_input_node which are not
409        # ancestors of compute_node.
410        def _is_ancestor_node(_current_node, _ancestor_node):
411            # Check whether _ancestor_node is an ancestor of _current_node
412            _node_list = [_current_node]
413            _visited_nodes = set()
414            while len(_node_list) != 0:
415                _current_node = _node_list.pop(0)
416                if _current_node not in _visited_nodes:
417                    _visited_nodes.add(_current_node)
418                    if _current_node == _ancestor_node:
419                        return True
420                    elif isinstance(
421                        _current_node, torch.fx.Node
422                    ) and _current_node.op not in ["placeholder", "output", "get_attr"]:
423                        for input in _current_node.all_input_nodes:
424                            _node_list.append(input)  # noqa: PERF402
425            return False
426
427        return [
428            user
429            for user in list(extra_input_node.users)
430            if not _is_ancestor_node(compute_node, user)
431        ]
432
433    def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index):
434        def fn(match):
435            if not _is_valid_computation_binary(computation_op, binary_op)(match):
436                return False
437            binary_nodes = filter_nodes(match.nodes, binary_op)
438
439            def _get_compute_node(_binary_node, _other_index):
440                assert (
441                    len(_binary_node.all_input_nodes) == 2
442                ), "Binary node should have 2 input nodes."
443                _compute_index = 1 if (_other_index == 0) else 0
444                return _binary_node.args[_compute_index]
445
446            def _other_input_not_inplaceable(_binary_node, _other_index):
447                _compute_node = _get_compute_node(_binary_node, _other_index)
448                return (
449                    len(
450                        _get_remaining_users(
451                            _binary_node.args[_other_index], _compute_node
452                        )
453                    )
454                    > 1
455                    or _binary_node.args[_other_index] == _compute_node.args[0]
456                )
457
458            if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes):
459                return False
460            if any(
461                n.args[other_index].op in ["placeholder", "output"]
462                for n in binary_nodes
463            ):
464                return False
465            return True
466
467        return fn
468
469    def _register_binary_unary_fusion_lowering(
470        pattern,
471        computation_op,
472        binary_op,
473        fusion_op,
474        unary_attr=None,
475    ):
476        @register_lowering_pattern(
477            pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op)
478        )
479        def fn(match, *args, **kwargs):
480            other = kwargs.get("other")
481            assert isinstance(other, ir.TensorBox)
482            binary_attr = _binary_attr[binary_op]
483            args_list = list(args)
484            computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
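            # len(args_list) > 6 distinguishes the conv pattern (10 positional args) from
            # the linear pattern (6), so the trailing alpha/unary post-op arguments are
            # only appended for the conv fused op.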
485            if len(args_list) > 6:
486                if unary_attr is not None:
487                    computation_args += [
488                        1.0,
489                        unary_attr.op_name,
490                        unary_attr.scalars_attr,
491                        unary_attr.algorithm_attr,
492                    ]
493                else:
494                    computation_args += [1.0, None, [], None]
495            return L[fusion_op](*computation_args)
496
497        return fn
498
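    # `other` can only be written in place if, after unwrapping plain Views, it is not a
    # ReinterpretView and none of its inputs alias its output; otherwise the in-place
    # fused op could clobber data that is still needed elsewhere.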
499    def _can_be_inplace(_other):
500        if isinstance(_other.data, ir.View):
501            return _can_be_inplace(_other.data)
502        else:
503            return not (
504                isinstance(_other.data, ir.ReinterpretView)
505                or len(_other.get_inputs_that_alias_output()) > 0
506            )
507
508    def _register_binary_unary_maybe_inplace_fusion_lowering(
509        pattern,
510        computation_op,
511        binary_op,
512        inplace_fusion_op,
513        outplace_fusion_op,
514        unary_attr=None,
515        other_index=None,
516    ):
517        @register_lowering_pattern(
518            pattern,
519            extra_check=_is_valid_computation_binary_inplace(
520                computation_op, binary_op, other_index
521            ),
522        )
523        def fn(match, *args, **kwargs):
524            other = kwargs.get("other")
525            assert isinstance(other, ir.TensorBox)
526            binary_attr = _binary_attr[binary_op]
527            args_list = list(args)
528            computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr]
529            if len(args_list) > 6:
530                if unary_attr is not None:
531                    computation_args += [
532                        1.0,
533                        unary_attr.op_name,
534                        unary_attr.scalars_attr,
535                        unary_attr.algorithm_attr,
536                    ]
537                else:
538                    computation_args += [1.0, None, [], None]
539            # Make sure `other` is not an alias or mutation (the FX side doesn't have such info).
540            other.realize()
541            if not _can_be_inplace(other):
542                return L[outplace_fusion_op](*computation_args)
543            return L[inplace_fusion_op](*computation_args)
544
545        return fn
546
547    computation_ops = [
548        mkldnn._convolution_pointwise.default,
549        mkldnn._linear_pointwise.default,
550        mkldnn._convolution_transpose_pointwise.default,
551    ]
552
553    class UnaryAttr:
554        def __init__(
555            self, op_name: str, scalars_attr=None, algorithm_attr=None
556        ) -> None:
557            self.op_name = op_name
558            self.scalars_attr = scalars_attr if scalars_attr else []
559            self.algorithm_attr = algorithm_attr if algorithm_attr else ""
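    # For example, UnaryAttr("gelu", algorithm_attr="tanh") supplies the trailing
    # (attr, scalars, algorithm) = ("gelu", [], "tanh") arguments of the fused mkldnn ops.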
560
561    def _register_unary_fusion():
562        computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call]
563
564        def _unary_fusion_patterns(lowp_dtype):
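            # The integer passed to _unary_fusion_pattern is the expected number of user
            # nodes of the computation output inside the decomposed activation, e.g. 4 for
            # tanh GeLU (mul-by-0.5, add, x*x and (x*x)*x all consume it).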
565            replacement_unary_fusion_patterns = {
566                UnaryAttr("gelu", algorithm_attr="tanh"): [
567                    _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype)
568                    for call_fn in computation_call_fns
569                ],
570                UnaryAttr("gelu", algorithm_attr="none"): [
571                    _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype)
572                    for call_fn in computation_call_fns
573                ],
574                UnaryAttr("hardswish"): [
575                    _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype)
576                    for call_fn in computation_call_fns
577                ],
578                UnaryAttr("hardsigmoid"): [
579                    _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype)
580                    for call_fn in computation_call_fns
581                ],
582                UnaryAttr("swish"): [
583                    _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype)
584                    for call_fn in computation_call_fns
585                ],
586            }
587            if not lowp_dtype:
588                call_user1 = [call_fn(users=1) for call_fn in computation_call_fns]
589                replacement_unary_fusion_patterns.update(
590                    {
591                        UnaryAttr("relu"): [
592                            _combined_fusion(u, aten.relu) for u in call_user1
593                        ],
594                        UnaryAttr("sigmoid"): [
595                            _combined_fusion(u, aten.sigmoid) for u in call_user1
596                        ],
597                        UnaryAttr("tanh"): [
598                            _combined_fusion(u, aten.tanh) for u in call_user1
599                        ],
600                    }
601                )
602
603            return replacement_unary_fusion_patterns
604
605        for lowp_dtype in [torch.bfloat16, torch.float16, None]:
606            replace_patterns = _unary_fusion_patterns(lowp_dtype)
607            for unary_attr, patterns in replace_patterns.items():
608                _register_unary_fusion_lowering(
609                    patterns[0], unary_attr, computation_ops[0], lowp_dtype
610                )
611                _register_unary_fusion_lowering(
612                    patterns[1], unary_attr, computation_ops[1], lowp_dtype
613                )
614                _register_unary_fusion_lowering(
615                    patterns[2], unary_attr, computation_ops[2], lowp_dtype
616                )
617            _leaky_relu_patterns = [
618                _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype)
619                for call_fn in computation_call_fns
620            ]
621            for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops):
622                _register_leaky_relu_fusion_lowering(
623                    pattern, computation_op, lowp_dtype
624                )
625            hardtanh_patterns = [
626                _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype)
627                for call_fn in computation_call_fns
628            ]
629            for pattern, computation_op in zip(hardtanh_patterns, computation_ops):
630                _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype)
631
632    def _register_inplace_fusion():
633        binary_ops = [aten.add, ops.add]
634        inplace_fusion_op = mkldnn._convolution_pointwise_.binary
635        outplace_fusion_op = mkldnn._convolution_pointwise.binary
636        conv_call = _conv_call(users=1)
637        conv_op = computation_ops[0]
638        for binary_op in binary_ops:
639            binary_v1 = _binary_fusion_v1(conv_call, binary_op)
640            binary_unary_v1 = _combined_fusion(binary_v1, aten.relu)
641            _register_binary_unary_maybe_inplace_fusion_lowering(
642                binary_unary_v1,
643                conv_op,
644                binary_op,
645                inplace_fusion_op,
646                outplace_fusion_op,
647                other_index=0,
648                unary_attr=UnaryAttr("relu"),
649            )
650            _register_binary_unary_maybe_inplace_fusion_lowering(
651                binary_v1,
652                conv_op,
653                binary_op,
654                inplace_fusion_op,
655                outplace_fusion_op,
656                other_index=0,
657            )
658            binary_v2 = _binary_fusion_v2(conv_call, binary_op)
659            binary_unary_v2 = _combined_fusion(binary_v2, aten.relu)
660            _register_binary_unary_maybe_inplace_fusion_lowering(
661                binary_unary_v2,
662                conv_op,
663                binary_op,
664                inplace_fusion_op,
665                outplace_fusion_op,
666                other_index=1,
667                unary_attr=UnaryAttr("relu"),
668            )
669            _register_binary_unary_maybe_inplace_fusion_lowering(
670                binary_v2,
671                conv_op,
672                binary_op,
673                inplace_fusion_op,
674                outplace_fusion_op,
675                other_index=1,
676            )
677
678    def _register_binary_fusion():
679        binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
680        fusion_ops = [
681            mkldnn._convolution_pointwise.binary,
682            mkldnn._linear_pointwise.binary,
683        ]
684        _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)]
685        for computation_call, computation_op, fusion_op in zip(
686            _computation_user_1, computation_ops[:-1], fusion_ops
687        ):
688            for binary_op in binary_ops:
689                pattern = _binary_fusion_v2(computation_call, binary_op)
690                _register_binary_unary_fusion_lowering(
691                    pattern, computation_op, binary_op, fusion_op
692                )
693
694            for binary_op in [aten.add, ops.add]:
695                pattern = _binary_fusion_v1(computation_call, binary_op)
696                _register_binary_unary_fusion_lowering(
697                    pattern, computation_op, binary_op, fusion_op
698                )
699
700    def _register_binary_unary_fusion():
701        binary_ops = [aten.add, ops.add, aten.sub, ops.sub]
702        fusion_ops = [mkldnn._convolution_pointwise.binary]
703        _computation_user_1 = [_conv_call(users=1)]
704        for computation_call, computation_op, fusion_op in zip(
705            _computation_user_1, computation_ops[:-1], fusion_ops
706        ):
707            for binary_op in binary_ops:
708                pattern_v1 = _combined_fusion(
709                    _binary_fusion_v2(computation_call, binary_op), aten.relu
710                )
711                _register_binary_unary_fusion_lowering(
712                    pattern_v1,
713                    computation_op,
714                    binary_op,
715                    fusion_op,
716                    unary_attr=UnaryAttr("relu"),
717                )
718            for binary_op in [aten.add, ops.add]:
719                pattern_v2 = _combined_fusion(
720                    _binary_fusion_v1(computation_call, binary_op), aten.relu
721                )
722                _register_binary_unary_fusion_lowering(
723                    pattern_v2,
724                    computation_op,
725                    binary_op,
726                    fusion_op,
727                    unary_attr=UnaryAttr("relu"),
728                )
729
730    def _recover_linear():
731        # Convert reshape+linear+reshape to a single linear so the linear fusion path can be applied.
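        # Illustrative (hypothetical) shapes: a 3D input [B, S, H] is reshaped to [B*S, H]
        # for _linear_pointwise and reshaped back to [B, S, O]; when the product of the
        # leading dims of reshape_2 equals reshape_1[0] and matches the linear input's
        # leading shape, both reshapes are redundant and can be removed.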
732        @register_freezing_graph_pattern(
733            CallFunction(
734                aten.reshape.default,
735                CallFunction(
736                    mkldnn._linear_pointwise.default,
737                    CallFunction(
738                        aten.reshape.default,
739                        Arg(),
740                        KeywordArg("reshape_1"),
741                        _users=MULTIPLE,
742                    ),
743                    Arg(),
744                    Arg(),
745                    Arg(),
746                    Arg(),
747                    Arg(),
748                ),
749                KeywordArg("reshape_2"),
750            ),
751            pass_number=1,
752        )
753        def reshape_linear_reshape_pattern(match, *args, **kwargs):
754            def get_val(val):
755                return val if isinstance(val, int) else val.meta.get("val")
756
757            reshape_1 = kwargs.get("reshape_1")
758            reshape_2 = kwargs.get("reshape_2")
759            assert isinstance(reshape_1, list)
760            assert isinstance(reshape_2, list)
761            assert len(reshape_1) == 2
762
763            graph = match.graph
764            reshape_2_node = match.output_node()
765            linear_input_node = reshape_2_node.args[0].args[0].args[0]
766            # check linear's input's shape[:-1] == reshape_2[:-1]
767            # and check product(reshape_2[:-1]) == reshape_1[0]
768            can_remove_reshape = linear_input_node.meta.get("val").shape[
769                :-1
770            ] == torch.Size([get_val(val) for val in reshape_2[:-1]])
771            can_remove_reshape = can_remove_reshape and (
772                reduce(
773                    operator.mul,
774                    [get_val(val) for val in reshape_2[:-1]],
775                )
776                == get_val(reshape_1[0])
777            )
778
779            if can_remove_reshape:
780                repl = graph.call_function(mkldnn._linear_pointwise.default, args)
781                repl.meta.update(reshape_2_node.meta)
782                reshape_2_node.replace_all_uses_with(repl)
783                old_linear_node = reshape_2_node.args[0]
784                reshape_1_node = old_linear_node.args[0]
785                graph.erase_node(reshape_2_node)
786                graph.erase_node(old_linear_node)
787                if len(reshape_1_node.users) == 0:
788                    graph.erase_node(reshape_1_node)
789
790        def is_linear_add_bias(match):
791            add_node = match.output_node()
792            linear_node = add_node.args[0]
793            packed_weight_node = linear_node.args[1]
794            assert packed_weight_node.target == mkldnn._reorder_linear_weight
795            transpose_weight_node = packed_weight_node.args[0]
796            assert transpose_weight_node.target == aten.permute.default
797            weight_meta = transpose_weight_node.args[0].meta.get("val")
798            bias_node = add_node.args[1]
799            if isinstance(bias_node, int):
800                # we only fold the bias if it is a constant
801                return False
802            bias_meta = add_node.args[1].meta.get("val")
803            if weight_meta is None or bias_meta is None:
804                return False
805            assert weight_meta.dtype in (
806                torch.bfloat16,
807                torch.float16,
808            )
809            if bias_meta.dtype != weight_meta.dtype:
810                return False
811            return (
812                linear_node.args[2] is None
813                and bias_meta.dim() == 1
814                and bias_meta.size(0) == weight_meta.size(1)
815            )
816
817        # Convert linear+bias to a single linear so the fusion path can be applied.
818        @register_freezing_graph_pattern(
819            CallFunction(
820                aten.add.Tensor,
821                CallFunction(mkldnn._linear_pointwise.default, *_linear_args),
822                Arg(),
823            ),
824            pass_number=1,
825            extra_check=is_linear_add_bias,
826        )
827        def linear_bias_pattern(match, *args):
828            graph = match.graph
829            add_node = match.output_node()
830            linear_node = add_node.args[0]
831            new_args = list(linear_node.args)
832            new_args[2] = add_node.args[1]
833            repl = graph.call_function(
834                mkldnn._linear_pointwise.default, tuple(new_args)
835            )
836            repl.meta.update(add_node.meta)
837            add_node.replace_all_uses_with(repl)
838            match.erase_nodes()
839
840    def _is_packable_mkldnn_rnn_layer(match):
841        lstm_node = match.output_node()
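        # Positions refer to _aten_mkldnn_rnn_layer_args below: 0 is the input, 1/2 are
        # weight0/weight1, and 5/6 are hx_/cx_.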
842        POS_WEIGHTS = [1, 2]
843        POS_INPUTS = [0, 5, 6]
844        POS_ARGS = POS_WEIGHTS + POS_INPUTS
845        # Weights should be Constant
846        if any(
847            lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS
848        ):
849            return False
850
851        # Meta info for weights and inputs should be available
852        if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS):
853            return False
854
855        # Check device
856        if any(
857            lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu"
858            for POS_ARG in POS_ARGS
859        ):
860            return False
861
862        # Check dtype
863        if any(
864            lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16
865            and not mkldnn._is_mkldnn_bf16_supported()
866            for POS_ARG in POS_ARGS
867        ):
868            return False
869        if any(
870            lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16
871            and not mkldnn._is_mkldnn_fp16_supported()
872            for POS_ARG in POS_ARGS
873        ):
874            return False
875
876        return True
877
878    def _is_packable_convolution(match):
879        """
880        Check if the node is supported for MKLDNN convolution.
881        """
882        conv_node = match.output_node()
883        input_meta_value = conv_node.args[0].meta.get("val")
884        weight_meta_value = conv_node.args[1].meta.get("val")
885        if input_meta_value is None or weight_meta_value is None:
886            return False
887        input_size = input_meta_value.shape
888        if conv_node.args[1].op != "get_attr":
889            return False
890        for meta_value in [input_meta_value, weight_meta_value]:
891            if (
892                meta_value is None
893                or meta_value.device.type != "cpu"
894                or (meta_value.dim() != 4 and meta_value.dim() != 5)
895            ):
896                return False
897        if (
898            input_meta_value.dtype == torch.bfloat16
899            or weight_meta_value.dtype == torch.bfloat16
900        ):
901            if not mkldnn._is_mkldnn_bf16_supported():
902                return False
903        if (
904            input_meta_value.dtype == torch.float16
905            or weight_meta_value.dtype == torch.float16
906        ):
907            if not mkldnn._is_mkldnn_fp16_supported():
908                return False
909        is_transposed = conv_node.args[-3]
910        if is_transposed:
911            # TODO: Support dynamic shape case for MKLDNN conv transpose.
912            if has_free_symbols(input_size):
913                return False
914            groups = conv_node.args[-1]
915            in_channels = weight_meta_value.size(0)
916            # mkldnn doesn't support grouped depthwise conv transpose.
917            if groups > 1 and groups == in_channels:
918                return False
919            # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big
920            output_paddings = conv_node.args[-2]
921            strides = conv_node.args[3]
922            if any(
923                output_padding >= stride
924                for output_padding, stride in zip(output_paddings, strides)
925            ):
926                return False
927        return True
928
929    def _is_packable_linear(match):
930        """
931        Check if the node is supported for MKLDNN linear.
932        """
933        linear_node = match.output_node()
934        # mkldnn linear only supports beta = 1 or 0 and alpha = 1
935        if linear_node.target == aten.addmm.default:
936            alpha = linear_node.kwargs.get("alpha", 1.0)
937            beta = linear_node.kwargs.get("beta", 1.0)
938            if (beta != 0.0 and beta != 1.0) or alpha != 1.0:
939                return False
940        # weight_idx is 1 for aten.mm and is 2 for aten.addmm
941        weight_idx = 2 if linear_node.target == aten.addmm.default else 1
942        if linear_node.args[weight_idx].op != "get_attr":
943            return False
944        input_meta_value = linear_node.args[weight_idx - 1].meta.get("val")
945        weight_meta_value = linear_node.args[weight_idx].meta.get("val")
946        if input_meta_value is None or weight_meta_value is None:
947            return False
948        batch_size = input_meta_value.shape[0]
949        if (
950            input_meta_value.dtype == torch.float64
951            or weight_meta_value.dtype == torch.float64
952        ):
953            return False
954        is_lp_weight = weight_meta_value.dtype in (
955            torch.bfloat16,
956            torch.float16,
957        )
958        # On x86, for fp32, MKL should be enabled and batch_size should not be a free symbol.
959        # On aarch64, use the mkldnn op for fp32 as well if ACL is enabled.
960        if (
961            not is_lp_weight
962            and not mkldnn._is_mkldnn_acl_supported()
963            and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
964        ):
965            return False
966        for meta_value in [input_meta_value, weight_meta_value]:
967            if (
968                meta_value is None
969                or meta_value.device.type != "cpu"
970                or meta_value.dim() != 2
971            ):
972                return False
973        if weight_idx == 2:
974            bias_meta_value = linear_node.args[0].meta.get("val")
975            if (
976                bias_meta_value is None
977                or bias_meta_value.device.type != "cpu"
978                or bias_meta_value.dim() != 1
979                or bias_meta_value.size(0) != weight_meta_value.size(1)
980            ):
981                return False
982
983        if (
984            input_meta_value.dtype == torch.bfloat16
985            or weight_meta_value.dtype == torch.bfloat16
986        ):
987            if not mkldnn._is_mkldnn_bf16_supported():
988                return False
989        if (
990            input_meta_value.dtype == torch.float16
991            or weight_meta_value.dtype == torch.float16
992        ):
993            if not mkldnn._is_mkldnn_fp16_supported():
994                return False
995        return True
996
997    _aten_conv_args = (
998        Arg(),
999        Arg(),
1000        Arg(),
1001        Arg(),
1002        Arg(),
1003        Arg(),
1004        KeywordArg("is_transposed"),
1005        Arg(),
1006        Arg(),
1007    )
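    # Positional layout mirrors aten.convolution.default:
    # (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups),
    # with the `transposed` flag captured as KeywordArg("is_transposed").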
1008
1009    _aten_mkldnn_rnn_layer_args = (
1010        Arg(),  # input
1011        Arg(),  # weight0
1012        Arg(),  # weight1
1013        Arg(),  # weight2
1014        Arg(),  # weight3
1015        Arg(),  # hx_
1016        Arg(),  # cx_
1017        KeywordArg("reverse"),  # reverse
1018        Arg(),  # batch_sizes
1019        Arg(),  # mode
1020        Arg(),  # hidden_size
1021        Arg(),  # num_layers
1022        Arg(),  # has_biases
1023        Arg(),  # bidirectional
1024        Arg(),  # batch_first
1025        Arg(),  # train
1026    )
1027
1028    def _register_weight_pack_pass():
1029        @register_freezing_graph_pattern(
1030            CallFunction(aten.convolution.default, *_aten_conv_args),
1031            extra_check=_is_packable_convolution,
1032        )
1033        def convolution(match, *args, **kwargs):
1034            is_transposed = kwargs.get("is_transposed")
1035            assert isinstance(is_transposed, bool)
1036            graph = match.graph
1037            conv_node = match.output_node()
1038            input_size = conv_node.args[0].meta.get("val").shape
1039            with graph.inserting_before(conv_node):
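                # From the aten.convolution schema: args[3]=stride, args[4]=padding,
                # args[5]=dilation, args[-2]=output_padding, args[-1]=groups; the order
                # below (padding, stride, dilation, groups) is what the mkldnn
                # reorder/pointwise ops expect.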
1040                constant_args = [args[4], args[3], args[5], args[-1]]
1041                packed_weight_op = mkldnn._reorder_convolution_weight
1042                packed_conv_op = mkldnn._convolution_pointwise.default
1043                if is_transposed:
1044                    constant_args.insert(1, args[-2])  # output_padding
1045                    packed_weight_op = mkldnn._reorder_convolution_transpose_weight
1046                    packed_conv_op = mkldnn._convolution_transpose_pointwise.default
1047                if not has_free_symbols(input_size):
1048                    packed_weight_inputs = (
1049                        (args[1],) + tuple(constant_args) + (input_size,)
1050                    )
1051                    packed_weight_node = graph.create_node(
1052                        "call_function", packed_weight_op, args=packed_weight_inputs
1053                    )
1054                else:
1055                    assert not is_transposed
1056                    # For the dynamic shape case, we need to pack the weight at runtime.
1057                    packed_weight_node = args[1]
1058                packed_conv_inputs = (
1059                    (args[0], packed_weight_node, args[2])
1060                    + tuple(constant_args)
1061                    + ("none", [], "")
1062                )
1063                packed_conv_node = graph.create_node(
1064                    "call_function", packed_conv_op, tuple(packed_conv_inputs)
1065                )
1066                conv_node.replace_all_uses_with(packed_conv_node)
1067                packed_conv_node.meta.update(conv_node.meta)
1068                graph.erase_node(conv_node)
1069
1070        @register_freezing_graph_pattern(
1071            CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args),
1072            extra_check=_is_packable_mkldnn_rnn_layer,
1073        )
1074        def mkldnn_rnn_layer(match, *args, **kwargs):
1075            def get_item(graph, node, index):
1076                return graph.call_function(operator.getitem, (node, index))
1077
1078            graph = match.graph
1079            lstm_node = match.output_node()
1080            input = args[0]
1081            weight0, weight1 = args[1:3]
1082            reverse = kwargs.get("reverse")
1083            packed_lstm_op = aten.mkldnn_rnn_layer.default
1084            hidden_size = args[9]
1085            has_biases = args[11]
1086            batch_first = args[13]
1087            with graph.inserting_before(lstm_node):
1088                packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default
1089                packed_weight_inputs = (
1090                    weight0,
1091                    weight1,
1092                    hidden_size,
1093                    reverse,
1094                    has_biases,
1095                    batch_first,
1096                )
1097                packed_weight_node = graph.create_node(
1098                    "call_function", packed_weight_op, packed_weight_inputs, {}, "name"
1099                )
1100                packed_weight_items = [
1101                    get_item(graph, packed_weight_node, i) for i in range(2)
1102                ]
1103                pack_lstm_inputs = (
1104                    args[0],
1105                    *packed_weight_items,
1106                    args[3],
1107                    args[4],
1108                    args[5],
1109                    args[6],
1110                    reverse,
1111                    *args[7:],
1112                )
1113
1114                packed_lstm_node = graph.create_node(
1115                    "call_function", packed_lstm_op, args=pack_lstm_inputs
1116                )
1117                lstm_node.replace_all_uses_with(packed_lstm_node)
1118                packed_lstm_node.meta.update(lstm_node.meta)
1119                graph.erase_node(lstm_node)
1120
1121        @register_freezing_graph_pattern(
1122            CallFunction(
1123                aten.addmm.default,
1124                Arg(),
1125                Arg(),
1126                Arg(),
1127                beta=KeywordArg("beta"),
1128                alpha=KeywordArg("alpha"),
1129            ),
1130            extra_check=_is_packable_linear,
1131        )
1132        @register_freezing_graph_pattern(
1133            CallFunction(aten.mm.default, Arg(), Arg()),
1134            extra_check=_is_packable_linear,
1135        )
1136        def linear(match, *args, **kwargs):
1137            graph = match.graph
1138            linear_node = match.output_node()
1139            input = args[0] if linear_node.target == aten.mm.default else args[1]
1140            bias = (
1141                None
1142                if linear_node.target == aten.mm.default
1143                or (
1144                    linear_node.target == aten.addmm.default
1145                    and linear_node.kwargs.get("beta", 1.0) == 0.0
1146                )
1147                else args[0]
1148            )
1149            weight = args[1] if linear_node.target == aten.mm.default else args[2]
1150            with graph.inserting_before(linear_node):
1151                transpose_weight_node = graph.create_node(
1152                    "call_function", aten.permute.default, (weight, (1, 0))
1153                )
1154                weight_dtype = weight.meta.get("val").dtype
1155                is_lp_weight = weight_dtype in (
1156                    torch.bfloat16,
1157                    torch.float16,
1158                )
1159                batch_size = input.meta.get("val").shape[0]
1160                if has_free_symbols(batch_size):
1161                    assert (
1162                        is_lp_weight or mkldnn._is_mkldnn_acl_supported()
1163                    ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
1164                # For the bfloat16 dynamic shape path, use the input size hint to pack the weight for better performance.
1165                packed_weight_inputs = (
1166                    transpose_weight_node,
1167                    batch_size.node.shape_env.size_hint(batch_size.node.expr)
1168                    if has_free_symbols(batch_size)
1169                    else batch_size,
1170                )
1171                # MKL packed matrix can't be copied to a different address because the internal implementation
1172                # depends on the alignment of internally-stored metadata.
1173                # In AOT mode we first save the packed weight; when it is loaded back,
1174                # it ends up at a different address, which doesn't work.
1175                # Therefore, disable the MKL prepacked linear in AOT mode.
1176                packed_weight_op = (
1177                    mkldnn._reorder_linear_weight
1178                    if (
1179                        is_lp_weight
1180                        or mkldnn._is_mkldnn_acl_supported()
1181                        or V.aot_compilation is True
1182                    )
1183                    else torch.ops.mkl._mkl_reorder_linear_weight
1184                )
1185                packed_weight_node = graph.create_node(
1186                    "call_function", packed_weight_op, args=packed_weight_inputs
1187                )
1188
1189                packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node)
1190                if (
1191                    is_lp_weight
1192                    or mkldnn._is_mkldnn_acl_supported()
1193                    or V.aot_compilation is True
1194                ):
1195                    packed_linear_inputs += (bias, "none", [], "")
1196                    packed_linear_op = mkldnn._linear_pointwise.default
1197                else:
1198                    packed_linear_inputs += (transpose_weight_node, bias, batch_size)
1199                    packed_linear_op = torch.ops.mkl._mkl_linear
1200                packed_linear_node = graph.create_node(
1201                    "call_function", packed_linear_op, packed_linear_inputs
1202                )
1203                linear_node.replace_all_uses_with(packed_linear_node)
1204                packed_linear_node.meta.update(linear_node.meta)
1205                graph.erase_node(linear_node)
1206
1207    def _eliminate_duplicate_packed_nodes(gm):
1208        """
1209        Combine packed weight nodes with the same inputs to reduce memory usage.
1210        For example:
1211        class Model(nn.Module):
1212            def __init__(self) -> None:
1213                super().__init__()
1214                self.linear = nn.Linear(32, 32, bias=True)
1215
1216            def forward(self, x):
1217                return self.linear(self.linear(x))
1218
1219        the above's packed weight nodes are duplicate if two linear calls have same input size.
1220        """
1221        if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
1222            return gm
1223
1224        packed_weight_ops = [
1225            torch._C._nn.mkldnn_reorder_conv2d_weight,
1226            torch._C._nn.mkldnn_reorder_conv3d_weight,
1227            mkldnn._reorder_convolution_transpose_weight,
1228            mkldnn._reorder_linear_weight,
1229            mkldnn._reorder_mkldnn_rnn_layer_weight,
1230        ]
1231        if torch._C.has_mkl:
1232            packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight)
1233
1234        for node in gm.graph.nodes:
1235            if node.target in packed_weight_ops and len(node.args[0].users) > 1:
1236                for user_node in list(node.args[0].users.keys()):
1237                    if (
1238                        user_node.target == node.target
1239                        and user_node != node
1240                        and user_node.args == node.args
1241                    ):
1242                        user_node.replace_all_uses_with(node)
1243                        gm.graph.erase_node(user_node)
1244
1245    @functools.lru_cache(None)
1246    def _mkldnn_fusion_init():
1247        # TODO: aarch64: enable op fusion for ACL once it supports fused operators. Disabling it for now,
1248        # otherwise even plain matmul or inner product cannot be accelerated with ACL.
1249        if (
1250            torch.backends.mkldnn.enabled
1251            and torch.backends.mkldnn.is_available()
1252            and not torch.ops.mkldnn._is_mkldnn_acl_supported()
1253        ):
1254            _register_unary_fusion()
1255            _register_inplace_fusion()
1256            _register_binary_unary_fusion()
1257            _register_binary_fusion()
1258            _register_quantization_lowerings()
1259            _register_woq_lowerings()
1260
1261    @functools.lru_cache(None)
1262    def _mkldnn_weight_pack_init():
1263        if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available():
1264            _register_weight_pack_pass()
1265            _recover_linear()
1266            _register_quantization_weight_pack_pass()
1267