1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import copy
4import functools
5import itertools
6import math
7import operator
8from typing import Any, Tuple
9
10import torch
11from torch._dynamo.utils import counters
12from torch.fx.experimental.symbolic_shapes import has_free_symbols
13from torch.fx.node import map_arg
14
15from ..lowering import lowerings as L, require_channels_last
16from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
17from ..utils import pad_listlike
18from .freezing_patterns import register_freezing_graph_pattern
19from .post_grad import register_lowering_pattern
20
21
22aten = torch.ops.aten
23prims = torch.ops.prims
24quantized_decomposed = torch.ops.quantized_decomposed
25quantized = torch.ops.quantized
26
27# Only for per-tensor quant, since permute may change the channel idx
28_PER_TENSOR_QUANTIZE_OPS = [
29    quantized_decomposed.quantize_per_tensor.default,
30    quantized_decomposed.quantize_per_tensor.tensor,
31]
32
33_VIEW_OPS = [
34    aten.transpose.int,
35    aten.permute.default,
36    aten.view.default,
37]
38
39"""
40The quantization.py file primarily contains passes related to quantization fusion
41in Inductor, including:
421. Dequant promotion;
432. Conv/GEMM weight prepack with the oneDNN library;
443. Conv/GEMM quantization fusion with the output quant node (if present);
454. Quantization fusion for other pointwise operators, such as qmaxpool2d, qcat and more.
46
47It also covers int8-mixed-fp32 and int8-mixed-bf16 quantization. The main differences
48between the int8-mixed-bf16 and int8-mixed-fp32 patterns are:
491. There is a to(dtype=torch.bfloat16) node at the activation and weight inputs of Conv/GEMM.
502. There is a to(dtype=torch.float32) node at the output of Conv/GEMM before it feeds the next quant node.
51Refer to https://github.com/pytorch/pytorch/issues/111640 for the detailed design of int8-mixed-bf16
52quantization.
53"""
54
55
56def _get_pattern_output_dtype(match: Match):
57    """
58    Get the pattern's output dtype from the output node's meta.
59    Assumes the matched pattern has exactly one output node.
60    """
61    pattern_output_nodes = match.output_nodes()
62    assert len(pattern_output_nodes) == 1
63    output_node = pattern_output_nodes[0]
64    assert isinstance(output_node, torch.fx.Node)
65    output_dtype = output_node.meta["val"].dtype
66    assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
67    return output_dtype
68
69
70def _may_generate_pattern_with_dtype_convert(
71    pattern, dtype=Arg(), with_dtype_convert=True, users=1
72):
73    if with_dtype_convert:
74        return CallFunction(
75            prims.convert_element_type.default,
76            pattern,
77            dtype,
78            _users=users,
79        )
80    else:
81        return pattern
82
83
84def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
85    if with_reshape:
86        return CallFunction(
87            torch.ops.aten.reshape.default,
88            pattern,
89            reshape_size,
90        )
91    else:
92        return pattern
93
94
95def _generate_linear_t_pattern(
96    _dequant_per_channel_pattern,
97    dtype,
98):
99    assert dtype in [torch.float32, torch.bfloat16]
100    t_pattern = CallFunction(
101        aten.permute.default,
102        _may_generate_pattern_with_dtype_convert(
103            _dequant_per_channel_pattern,
104            KeywordArg("autocast_wgt_dtype"),
105            dtype == torch.bfloat16,
106        ),
107        KeywordArg("permute_axes"),
108    )
109    return t_pattern
110
111
112def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
113    # only insert to_dtype if is_bf16 is True
114    computation_call = _may_generate_pattern_with_dtype_convert(
115        call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
116    )
117    return unary_fusion(computation_call)
118
119
120def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
121    dequantize_per_tensor_activation_pattern = CallFunction(
122        quantized_decomposed.dequantize_per_tensor.tensor
123        if is_tensor_overload
124        else quantized_decomposed.dequantize_per_tensor.default,
125        KeywordArg("x"),
126        KeywordArg("x_scale"),
127        KeywordArg("x_zp"),
128        KeywordArg("x_quant_min"),
129        KeywordArg("x_quant_max"),
130        KeywordArg("x_dq_dtype"),
131    )
132    return dequantize_per_tensor_activation_pattern
133
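# For reference, the pattern above matches the decomposed dequantize call produced by the
# PT2E quantization flow, e.g. (values here are illustrative, not taken from this file):
#
#   quantized_decomposed.dequantize_per_tensor.default(
#       x, x_scale, x_zp, x_quant_min, x_quant_max, x_dq_dtype
#   )
#   # e.g. dequantize_per_tensor.default(x, 0.02, 128, 0, 255, torch.float32)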
134
135dequantize_per_channel_weight_pattern = CallFunction(
136    quantized_decomposed.dequantize_per_channel.default,
137    KeywordArg("q_weight"),
138    KeywordArg("w_scale"),
139    KeywordArg("w_zp"),
140    KeywordArg("w_axis"),
141    KeywordArg("w_quant_min"),
142    KeywordArg("w_quant_max"),
143    KeywordArg("w_dtype"),
144)
145
146dequantize_per_channel_to_bf16_weight_pattern = (
147    _may_generate_pattern_with_dtype_convert(
148        dequantize_per_channel_weight_pattern,
149        KeywordArg("autocast_wgt_dtype"),
150    )
151)
152
153dequantize_per_channel_clone_weight_pattern = CallFunction(
154    aten.clone.default,
155    dequantize_per_channel_weight_pattern,
156    memory_format=KeywordArg("memory_format"),
157)
158
159dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
160    aten.clone.default,
161    dequantize_per_channel_to_bf16_weight_pattern,
162    memory_format=KeywordArg("memory_format"),
163)
164
165
166def get_dequantize_qconv_pt2e_pattern(users=1):
167    return CallFunction(
168        torch.ops.onednn.qconv2d_pointwise.default,
169        KeywordArg("x"),
170        KeywordArg("x_scale"),  # x_scale
171        KeywordArg("x_zp"),  # x_zp
172        KeywordArg("packed_weight"),  # packed_weight
173        KeywordArg("w_scale"),  # w_scale
174        KeywordArg("w_zp"),  # w_zp
175        KeywordArg("b"),  # bias
176        KeywordArg("stride"),
177        KeywordArg("padding"),
178        KeywordArg("dilation"),
179        KeywordArg("groups"),
180        KeywordArg("output_scale"),  # output_scale = 1.0
181        KeywordArg("output_zero_point"),  # output_zero_point = 0
182        KeywordArg("output_dtype"),  # output_dtype = None
183        KeywordArg("attr"),  # attr = "none"
184        Arg(),  # scalars
185        Arg(),  # algorithm
186        _users=users,
187    )
188
189
190def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
191    qlinear_op = (
192        torch.ops.onednn.qlinear_pointwise.tensor
193        if x_scale_zp_are_tensors
194        else torch.ops.onednn.qlinear_pointwise.default
195    )
196    return CallFunction(
197        qlinear_op,
198        KeywordArg("x"),
199        KeywordArg("x_scale"),
200        KeywordArg("x_zp"),
201        KeywordArg("packed_weight"),
202        KeywordArg("w_scale"),
203        KeywordArg("w_zp"),
204        KeywordArg("b"),
205        KeywordArg("output_scale"),
206        KeywordArg("output_zero_point"),
207        KeywordArg("output_dtype"),
208        KeywordArg("postop_name"),
209        KeywordArg("postop_args"),
210        KeywordArg("postop_algorithm"),
211        _users=users,
212    )
213
214
215dequantize_accum_pattern = CallFunction(
216    quantized_decomposed.dequantize_per_tensor.default,
217    KeywordArg("accum"),
218    KeywordArg("accum_scale"),
219    KeywordArg("accum_zp"),
220    Arg(),
221    Arg(),
222    KeywordArg("accum_dq_dtype"),
223)
224
225
226def generate_pattern_with_binary(
227    binary_post_op,
228    computation_call,
229    extra_input_pattern,
230    dtype_convert=False,
231    swap_inputs=False,
232):
233    binary_pattern = (
234        CallFunction(
235            binary_post_op,
236            extra_input_pattern,
237            computation_call,
238        )
239        if swap_inputs
240        else CallFunction(
241            binary_post_op,
242            computation_call,
243            extra_input_pattern,
244        )
245    )
246    return _may_generate_pattern_with_dtype_convert(
247        binary_pattern,
248        KeywordArg("convert_dtype_after_inplace_add"),
249        dtype_convert,
250    )
251
252
253def generate_pattern_with_unary(computation_call, unary_post_op):
254    if unary_post_op is not None:
255        return CallFunction(
256            unary_post_op,
257            computation_call,
258        )
259    return computation_call
260
261
262def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
263    quantized_op_output_pattern_pt2e = CallFunction(
264        quantized_decomposed.quantize_per_tensor.default,
265        _may_generate_pattern_with_dtype_convert(
266            computation_call,
267            Arg(),
268            with_dtype_convert,
269        ),
270        KeywordArg("o_inv_scale"),
271        KeywordArg("o_zp"),
272        KeywordArg("o_qmin"),
273        KeywordArg("o_qmax"),
274        KeywordArg("o_dtype"),
275    )
276    return quantized_op_output_pattern_pt2e
277
278
279def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
280    if kwarg_name in check_node.kwargs:
281        actual_value = check_node.kwargs[kwarg_name]
282        return actual_value == expected_value
283    else:
284        assert len(check_node.args) >= (args_index + 1)
285        actual_value = check_node.args[args_index]
286        return actual_value == expected_value
287
288
289def _is_valid_quantized_conv2d_optimization_pattern():
290    def fn(match):
291        output_dtype = _get_pattern_output_dtype(match)
292        if output_dtype in [torch.float32, torch.bfloat16]:
293            # Only keep the match if the qconv node's output_dtype kwarg equals the pattern's output dtype
294            qconv_node_after_weight_prepack = filter_nodes(
295                match.nodes, torch.ops.onednn.qconv2d_pointwise
296            )[0]
297            return _check_node_kwarg_arg_value(
298                qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
299            )
300        return True
301
302    return fn
303
304
305def _register_quantized_conv_lowering(
306    pattern,
307    pass_number,
308    computation_op,
309    unary_attr,
310):
311    @register_lowering_pattern(
312        pattern,
313        extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
314        pass_number=pass_number,
315    )
316    def qconv(match: Match, *args, **kwargs):
317        # Activation QParams
318        x, x_scale, x_zp = (
319            kwargs["x"],
320            kwargs["x_scale"],
321            kwargs["x_zp"],
322        )
323        # Weight QParams
324        packed_weight, w_scale, w_zp = (
325            kwargs["packed_weight"],
326            kwargs["w_scale"],
327            kwargs["w_zp"],
328        )
329        # Conv Params
330        b, stride, padding, dilation, groups = (
331            kwargs["b"],
332            kwargs["stride"],
333            kwargs["padding"],
334            kwargs["dilation"],
335            kwargs["groups"],
336        )
337        output_dtype = _get_pattern_output_dtype(match)
338        assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16]
339        # Output QParams
340        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
341        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
342        assert (
343            kwargs["attr"] == "none"
344        )  # No post op is expected to be fused in the weight prepack phase
345        if unary_attr.op_name == "hardtanh":
346            min_value = kwargs.get("min_value")
347            max_value = kwargs.get("max_value")
348            unary_attr.scalars_attr = [min_value, max_value]
349
350        computation_args = (
351            x,
352            x_scale,
353            x_zp,
354            packed_weight,
355            w_scale,
356            w_zp,
357            b,
358            stride,
359            padding,
360            dilation,
361            groups,
362            o_inv_scale,
363            o_zero_point,
364            output_dtype,
365            unary_attr.op_name,
366            unary_attr.scalars_attr,
367            unary_attr.algorithm_attr,
368        )
369        counters["inductor"]["qconv2d_unary_matcher_count"] += 1
370        counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
371        return L[computation_op](*computation_args)
372
373    return qconv
374
375
376def _is_valid_quantized_linear_optimization_pattern():
377    def fn(match):
378        output_dtype = _get_pattern_output_dtype(match)
379        if output_dtype in [torch.float32, torch.bfloat16]:
380            # Only keep the match if the qlinear node's output_dtype kwarg equals the pattern's output dtype
381            qlinear_node_after_weight_prepack = filter_nodes(
382                match.nodes, torch.ops.onednn.qlinear_pointwise
383            )[0]
384            return _check_node_kwarg_arg_value(
385                qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
386            )
387        return True
388
389    return fn
390
391
392def _register_quantized_linear_lowering(
393    pattern,
394    pass_number,
395    computation_op,
396    unary_attr,
397):
398    @register_lowering_pattern(
399        pattern,
400        extra_check=_is_valid_quantized_linear_optimization_pattern(),
401        pass_number=pass_number,
402    )
403    def qlinear(match: Match, *args, **kwargs):
404        output_dtype = _get_pattern_output_dtype(match)
405        # Activation QParams
406        x, x_scale, x_zp = (
407            kwargs["x"],
408            kwargs["x_scale"],
409            kwargs["x_zp"],
410        )
411        # Weight QParams
412        packed_weight, w_scale, w_zp = (
413            kwargs["packed_weight"],
414            kwargs["w_scale"],
415            kwargs["w_zp"],
416        )
417
418        # bias
419        b = kwargs["b"] if "b" in kwargs else None
420
421        # Output QParams
422        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
423        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
424        assert (
425            kwargs["postop_name"] == "none"
426        )  # No post op is expected to be fused in the weight prepack phase
427
428        computation_args = (
429            x,
430            x_scale,
431            x_zp,
432            packed_weight,
433            w_scale,
434            w_zp,
435            b,
436            o_inv_scale,
437            o_zero_point,
438            output_dtype,
439            unary_attr.op_name,
440            unary_attr.scalars_attr,
441            unary_attr.algorithm_attr,
442        )
443        counters["inductor"]["qlinear_unary_matcher_count"] += 1
444        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
445        return L[computation_op](*computation_args)
446
447    return qlinear
448
449
450def _register_quantized_linear_binary_lowering(
451    pattern,
452    pass_number,
453    computation_op,
454    binary_unary_attr,
455):
456    @register_lowering_pattern(
457        pattern,
458        extra_check=_is_valid_qlinear_binary_optimization_pattern(),
459        pass_number=pass_number,
460    )
461    def qlinear_binary(match: Match, *args, **kwargs):
462        output_dtype = _get_pattern_output_dtype(match)
463        assert output_dtype is not None
464        # Activation QParams
465        x, x_scale, x_zp = (
466            kwargs["x"],
467            kwargs["x_scale"],
468            kwargs["x_zp"],
469        )
470        x2 = (
471            kwargs["accum"]
472            if binary_unary_attr.binary_op_name == "sum"
473            else kwargs["other"]
474        )
475        x2_scale = 1.0
476        x2_zp = 0
477        # Weight QParams
478        packed_weight, w_scale, w_zp = (
479            kwargs["packed_weight"],
480            kwargs["w_scale"],
481            kwargs["w_zp"],
482        )
483        # bias
484        b = kwargs["b"] if "b" in kwargs else None
485        # Output QParams
486        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
487        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
488
489        x2.realize()
490        from .mkldnn_fusion import _can_be_inplace
491
492        binary_op_name = binary_unary_attr.binary_op_name
493
494        if binary_op_name == "sum" and not _can_be_inplace(x2):
495            # When we enable the GEMM Template, the output of QLinear
496            # will be reshaped from 2D back to 3D if the input is 3D.
497            # This causes _can_be_inplace(x2) to return False if x2 happens
498            # to be the output of QLinear in this scenario.
499            # Change the post op from sum to binary add for this case.
500            # Refer to test case:
501            #   test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2
502            binary_op_name = "add"
503
504        computation_args = (
505            x,
506            x_scale,
507            x_zp,
508            packed_weight,
509            w_scale,
510            w_zp,
511            x2,
512            b,
513            o_inv_scale,
514            o_zero_point,
515            output_dtype,
516            x2_scale,
517            x2_zp,
518            binary_op_name,
519            binary_unary_attr.alpha,
520            binary_unary_attr.unary_op_name,
521            binary_unary_attr.scalars_attr,
522            binary_unary_attr.algorithm_attr,
523        )
524        counters["inductor"]["qlinear_binary_matcher_count"] += 1
525        counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes)
526        return L[computation_op](*computation_args)
527
528    return qlinear_binary
529
530
531def _is_valid_qconv_binary_optimization_pattern():
532    return _is_valid_quantized_op_binary_optimization_pattern(
533        torch.ops.onednn.qconv2d_pointwise
534    )
535
536
537def _is_valid_qlinear_binary_optimization_pattern():
538    return _is_valid_quantized_op_binary_optimization_pattern(
539        torch.ops.onednn.qlinear_pointwise,
540        # we don't insert q-dq for extra input due to accuracy issues
541        extra_input_from_dequant=False,
542    )
543
544
545def _is_valid_quantized_op_binary_optimization_pattern(
546    qop, extra_input_from_dequant=True
547):
548    # Check if it's a valid Binary Pattern for qconv2d and qlinear:
549    # * qop_pointwise should have only one user
550    # * If extra_input_from_dequant is True, the extra input of the binary node should come from a dequant pattern
551    # * The two inputs of the binary node should have a "meta" attribute and should be tensors
552    # * The two inputs of the binary node should have the same shape
553    # * All users of the extra input in this pattern should be
554    #   ancestor nodes of the compute node, except for the binary node
555    #   connected to the compute node.
556    def fn(match):
557        output_dtype = _get_pattern_output_dtype(match)
558        compute_node = filter_nodes(match.nodes, qop)[0]
559        # qop_pointwise should only have one user
560        if len(compute_node.users) != 1:
561            return False
562        binary_node_inputs = next(iter(compute_node.users)).args
563        assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
564        if output_dtype in [torch.float32, torch.bfloat16]:
565            extra_input_of_binary_node = None
566            for arg in binary_node_inputs:
567                if arg != compute_node:
568                    extra_input_of_binary_node = arg
569                    break
570            assert extra_input_of_binary_node is not None
571            # Extra input of binary node comes from dequant pattern
572            if extra_input_from_dequant and (
573                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
574                or (
575                    extra_input_of_binary_node.target
576                    != quantized_decomposed.dequantize_per_tensor.default
577                )
578            ):
579                return False
580
581        # the two inputs of binary node should have attribute "meta" and should be tensors
582        if not (
583            hasattr(binary_node_inputs[0], "meta")
584            and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
585        ) or not (
586            hasattr(binary_node_inputs[1], "meta")
587            and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
588        ):
589            return False
590        # the two inputs of binary node should have the same shape
591        if (
592            binary_node_inputs[0].meta["val"].size()  # type: ignore[union-attr]
593            != binary_node_inputs[1].meta["val"].size()  # type: ignore[union-attr]
594        ):
595            return False
596
597        # All users of the extra input in this pattern should be
598        # ancestor nodes of the compute node, except for the binary node
599        # connected to the compute node.
600
601        from .mkldnn_fusion import _get_remaining_users
602
603        extra_input_of_pattern = (
604            match.kwargs["other"]
605            if "other" in match.kwargs
606            else (
607                match.kwargs["accum"]
608                if output_dtype == torch.uint8 or (not extra_input_from_dequant)
609                else match.kwargs["accum_after_dequant"]
610            )
611        )
612        if (
613            len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
614            or extra_input_of_pattern == compute_node.args[0]
615        ):
616            return False
617        return True
618
619    return fn
620
621
622def _register_quantized_conv_binary_lowering(
623    pattern,
624    pass_number,
625    computation_op,
626    binary_unary_attr,
627):
628    @register_lowering_pattern(
629        pattern,
630        extra_check=_is_valid_qconv_binary_optimization_pattern(),
631        pass_number=pass_number,
632    )
633    def qconv_binary(match: Match, *args, **kwargs):
634        output_dtype = _get_pattern_output_dtype(match)
635        assert output_dtype is not None
636        x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
637        accum = (
638            kwargs["accum"]
639            if output_dtype == torch.uint8
640            else kwargs["accum_after_dequant"]
641        )
642        accum_scale = kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0
643        accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0
644        packed_weight, w_scale, w_zp = (
645            kwargs["packed_weight"],
646            kwargs["w_scale"],
647            kwargs["w_zp"],
648        )
649        b, stride, padding, dilation, groups = (
650            kwargs["b"],
651            kwargs["stride"],
652            kwargs["padding"],
653            kwargs["dilation"],
654            kwargs["groups"],
655        )
656        # Output QParams
657        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
658        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
659
660        accum.realize()
661        from .mkldnn_fusion import _can_be_inplace
662
663        assert _can_be_inplace(
664            accum
665        ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
666
667        computation_args = (
668            x,
669            x_scale,
670            x_zp,
671            accum,
672            accum_scale,
673            accum_zp,
674            packed_weight,
675            w_scale,
676            w_zp,
677            b,
678            stride,
679            padding,
680            dilation,
681            groups,
682            o_inv_scale,
683            o_zero_point,
684            output_dtype,
685            binary_unary_attr.binary_op_name,
686            binary_unary_attr.alpha,
687            binary_unary_attr.unary_op_name,
688            binary_unary_attr.scalars_attr,
689            binary_unary_attr.algorithm_attr,
690        )
691        counters["inductor"]["qconv2d_binary_matcher_count"] += 1
692        counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes)
693        return L[computation_op](*computation_args)
694
695    return qconv_binary
696
697
698def _register_quantization_unary_fusion():
699    from .mkldnn_fusion import (
700        _gelu_fusion_1 as _gelu_fusion_erf,
701        _gelu_fusion_2 as _gelu_fusion_tanh,
702        _hardswish_fusion,
703        _hardtanh_fusion,
704        _silu_fusion,
705    )
706
707    class UnaryAttr:
708        def __init__(
709            self, op_name: str, scalars_attr=None, algorithm_attr=None
710        ) -> None:
711            self.op_name = op_name
712            self.scalars_attr = scalars_attr if scalars_attr else []
713            self.algorithm_attr = algorithm_attr if algorithm_attr else ""
714
715    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
716        # QConv2d
717        # Priority 1 to match: QConv2d Unary pattern with int8 output
718        # If pattern1 is a subset of pattern2, we should try to match pattern2 first.
719        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
720        is_bf16 = original_pattern_output_dtype == torch.bfloat16
721        conv_unary_replace_patterns = {
722            UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
723                get_dequantize_qconv_pt2e_pattern(1),
724            ),
725            UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
726                generate_pattern_with_unary(
727                    get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
728                ),
729            ),
730            UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
731                _unary_fusion_pattern(
732                    _hardtanh_fusion,
733                    get_dequantize_qconv_pt2e_pattern(1),
734                    1,
735                    is_bf16,
736                ),
737                with_dtype_convert=is_bf16,
738            ),
739            UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
740                _unary_fusion_pattern(
741                    _hardswish_fusion,
742                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
743                    2,
744                    is_bf16,
745                ),
746                with_dtype_convert=is_bf16,
747            ),
748            UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
749                _unary_fusion_pattern(
750                    _silu_fusion,
751                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
752                    2,
753                    is_bf16,
754                ),
755                with_dtype_convert=is_bf16,
756            ),
757        }
758
759        for unary_attr, patterns in conv_unary_replace_patterns.items():
760            # Register qconv2d pattern for ExternKernel Lowering
761            _register_quantized_conv_lowering(
762                patterns,
763                1,  # pass_number
764                torch.ops.onednn.qconv2d_pointwise,  # computation_op
765                unary_attr,  # unary_attr
766            )
767
768        # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
769        conv_unary_replace_float_out_patterns = {
770            UnaryAttr("relu", [], ""): generate_pattern_with_unary(
771                get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
772            ),
773            UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
774                _unary_fusion_pattern(
775                    _hardtanh_fusion,
776                    get_dequantize_qconv_pt2e_pattern(1),
777                    1,
778                    is_bf16,
779                ),
780                Arg(),
781                is_bf16,
782            ),
783            UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
784                _unary_fusion_pattern(
785                    _hardswish_fusion,
786                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
787                    2,
788                    is_bf16,
789                ),
790                Arg(),
791                is_bf16,
792            ),
793            UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
794                _unary_fusion_pattern(
795                    _silu_fusion,
796                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
797                    2,
798                    is_bf16,
799                ),
800                Arg(),
801                is_bf16,
802            ),
803        }
804
805        for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
806            # Register qconv2d pattern for ExternKernel Lowering
807            _register_quantized_conv_lowering(
808                patterns,
809                2,  # pass_number
810                torch.ops.onednn.qconv2d_pointwise,  # computation_op
811                unary_attr,  # unary_attr
812            )
813
814        # QLinear
815        for x_scale_zp_are_tensors in (False, True):
816            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
817            # Priority 1 to match: QLinear Unary pattern with int8 output
818            linear_unary_replace_patterns = {
819                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
820                    qlinear_pattern,
821                ),
822                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
823                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
824                ),
825                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
826                    _unary_fusion_pattern(
827                        _gelu_fusion_erf,
828                        get_qlinear_pt2e_pattern(
829                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
830                        ),
831                        2,
832                        is_bf16,
833                    ),
834                    with_dtype_convert=is_bf16,
835                ),
836                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
837                    _unary_fusion_pattern(
838                        _gelu_fusion_tanh,
839                        get_qlinear_pt2e_pattern(
840                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
841                        ),
842                        4,
843                        is_bf16,
844                    ),
845                    with_dtype_convert=is_bf16,
846                ),
847            }
848
849            for unary_attr, patterns in linear_unary_replace_patterns.items():
850                _register_quantized_linear_lowering(
851                    patterns,
852                    1,  # pass_number
853                    torch.ops.onednn.qlinear_pointwise,  # computation_op
854                    unary_attr,  # unary_attr
855                )
856
857            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
858            linear_unary_replace_float_out_patterns = {
859                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
860                    qlinear_pattern, aten.relu.default
861                ),
862                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
863                    _unary_fusion_pattern(
864                        _gelu_fusion_erf,
865                        get_qlinear_pt2e_pattern(
866                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
867                        ),
868                        2,
869                        is_bf16,
870                    ),
871                    Arg(),
872                    is_bf16,
873                ),
874                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
875                    _unary_fusion_pattern(
876                        _gelu_fusion_tanh,
877                        get_qlinear_pt2e_pattern(
878                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
879                        ),
880                        4,
881                        is_bf16,
882                    ),
883                    Arg(),
884                    is_bf16,
885                ),
886            }
887
888            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
889                _register_quantized_linear_lowering(
890                    patterns,
891                    2,  # pass_number
892                    torch.ops.onednn.qlinear_pointwise,  # computation_op
893                    unary_attr,  # unary_attr
894                )
895
896
897def _register_quantization_binary_fusion():
898    class BinaryUnaryAttr:
899        def __init__(
900            self,
901            binary_op_name: str,
902            alpha=None,
903            unary_op_name: str = "none",
904            scalars_attr=None,
905            algorithm_attr=None,
906        ) -> None:
907            self.binary_op_name = binary_op_name
908            self.alpha = alpha if alpha else 1.0
909            self.unary_op_name = unary_op_name
910            self.scalars_attr = scalars_attr if scalars_attr else []
911            self.algorithm_attr = algorithm_attr if algorithm_attr else ""
912
913    for int8_mixed_bf16_with_inplace_add in [False, True]:
914        # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
915        binary_replace_patterns = {
916            BinaryUnaryAttr(
917                "sum", 1.0, "none", [], ""
918            ): generate_pattern_with_output_quant(
919                generate_pattern_with_binary(
920                    aten.add.Tensor,
921                    get_dequantize_qconv_pt2e_pattern(1),
922                    dequantize_accum_pattern,
923                    int8_mixed_bf16_with_inplace_add,
924                ),
925            ),
926            BinaryUnaryAttr(
927                "sum", 1.0, "relu", [], ""
928            ): generate_pattern_with_output_quant(
929                generate_pattern_with_unary(
930                    generate_pattern_with_binary(
931                        aten.add.Tensor,
932                        get_dequantize_qconv_pt2e_pattern(1),
933                        dequantize_accum_pattern,
934                        int8_mixed_bf16_with_inplace_add,
935                    ),
936                    aten.relu.default,
937                ),
938            ),
939        }
940
941        for binary_unary_attr, patterns in binary_replace_patterns.items():
942            _register_quantized_conv_binary_lowering(
943                patterns,
944                0,  # pass_number
945                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
946                binary_unary_attr,  # binary_unary_attr
947            )
948
949        # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
950        binary_replace_float_out_patterns = {
951            BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary(
952                generate_pattern_with_binary(
953                    aten.add.Tensor,
954                    get_dequantize_qconv_pt2e_pattern(1),
955                    KeywordArg("accum_after_dequant"),
956                    int8_mixed_bf16_with_inplace_add,
957                ),
958                aten.relu.default,
959            ),
960        }
961
962        for (
963            binary_unary_attr,
964            patterns,
965        ) in binary_replace_float_out_patterns.items():
966            if int8_mixed_bf16_with_inplace_add:
967                _register_quantized_conv_binary_lowering(
968                    patterns,
969                    0,  # pass_number
970                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
971                    binary_unary_attr,  # binary_unary_attr
972                )
973            else:
974                _register_quantized_conv_binary_lowering(
975                    patterns,
976                    1,  # pass_number
977                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
978                    binary_unary_attr,  # binary_unary_attr
979                )
980
981        # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
982        binary_replace_float_out_patterns = {
983            BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary(
984                aten.add.Tensor,
985                get_dequantize_qconv_pt2e_pattern(1),
986                KeywordArg("accum_after_dequant"),
987                int8_mixed_bf16_with_inplace_add,
988            ),
989        }
990
991        for (
992            binary_unary_attr,
993            patterns,
994        ) in binary_replace_float_out_patterns.items():
995            _register_quantized_conv_binary_lowering(
996                patterns,
997                1 if int8_mixed_bf16_with_inplace_add else 2,  # pass_number
998                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
999                binary_unary_attr,  # binary_unary_attr
1000            )
1001
1002    # QLinear
1003    r"""
1004    Supported linear-binary(-unary) patterns
1005
1006        linear(X)   extra input
1007               \   /
1008                Add
1009                 |
1010            Optional(relu)
1011                 |
1012                 Y
1013
1014    1. int8-mixed-fp32
1015    +---+---------------+-----------+------------------------------+---------+
1016    | # | Add type      | Quant out | Pattern                      | Post op |
1017    +---+---------------+-----------+------------------------------+---------+
1018    | 1 | In-/out-place | Yes       | linear + fp32 -> (relu) -> q | add     |
1019    +---+---------------+-----------+------------------------------+---------+
1020    | 2 | In-/out-place | No        | linear + fp32 -> (relu)      | sum     |
1021    +---+---------------+-----------+------------------------------+---------+
1022
1023    2. int8-mixed-bf16
1024    +---+----------+---------------+-----------+-----------------------------------------+---------+
1025    | # | X2 dtype | Add type      | Quant out | Pattern                                 | Post op |
1026    +---+----------+---------------+-----------+-----------------------------------------+---------+
1027    | 1 | BF16     | In-/out-place | Yes       | linear + bf16 -> (relu) -> q            | add     |
1028    +---+----------+---------------+-----------+-----------------------------------------+---------+
1029    | 2 | BF16     | In-/out-place | No        | linear + bf16 -> (relu)                 | sum     |
1030    +---+----------+---------------+-----------+-----------------------------------------+---------+
1031    | 3 | FP32     | Out-place     | Yes       | linear + fp32 -> (relu) -> q            | add     |
1032    |   |          | In-place right|           |                                         |         |
1033    +---+----------+---------------+-----------+-----------------------------------------+---------+
1034    | 4 | FP32     | Out-place     | No        | linear + fp32 -> (relu)                 | sum     |
1035    |   |          | In-place right|           |                                         |         |
1036    +---+----------+---------------+-----------+-----------------------------------------+---------+
1037    | 5 | FP32     | In-place left | Yes       | linear + fp32 -> to_bf16 -> (relu) -> q | add     |
1038    +---+----------+---------------+-----------+-----------------------------------------+---------+
1039    | 6 | FP32     | In-place left | No        | linear + fp32 -> to_bf16 -> (relu)      | add     |
1040    +---+----------+---------------+-----------+-----------------------------------------+---------+
1041
1042    Note
1043    (1) The positions of linear and the extra input can be swapped.
1044    (2) By recipe, we don't insert q-dq before the extra input of linear-add. But if a q-dq is found at the
1045    extra input, we don't match that pattern, because we cannot match all these patterns in 3 passes.
1046    """
1047    for x_scale_zp_are_tensors in (False, True):
1048        qlinear_binary_op = (
1049            torch.ops.onednn.qlinear_pointwise.binary_tensor
1050            if x_scale_zp_are_tensors
1051            else torch.ops.onednn.qlinear_pointwise.binary
1052        )
1053        unary_postop_list = ["none", "relu"]
1054        unary_postop_dict = {
1055            "none": None,
1056            "relu": aten.relu.default,
1057        }
1058        convert_dtype_after_binary_list = [False, True]
1059
1060        # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output
1061        # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16,
1062        # 3 distinct patterns in total (2 of the cases yield identical patterns)
1063        swap_binary_inputs_list = [False, True]
1064        int8_mixed_bf16_list = [False, True]
1065        combinations = itertools.product(
1066            unary_postop_list,
1067            int8_mixed_bf16_list,
1068            swap_binary_inputs_list,
1069            convert_dtype_after_binary_list,
1070        )
1071        qlinear_binary_replace_patterns = {}
1072        for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations:
1073            if not int8_mixed_bf16 and cvt_dtype_binary:
1074                # No convert node after binary node if dtypes are all fp32
1075                continue
1076            qlinear_binary_replace_patterns.update(
1077                {
1078                    BinaryUnaryAttr(
1079                        "add", 1.0, unary_op, [], ""
1080                    ): generate_pattern_with_output_quant(
1081                        generate_pattern_with_unary(
1082                            generate_pattern_with_binary(
1083                                aten.add.Tensor,
1084                                get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1085                                KeywordArg("other"),
1086                                # If the fp32 extra input is in-place added to the bf16 linear output,
1087                                # a to_bf16 node is inserted after the binary node
1088                                dtype_convert=cvt_dtype_binary,
1089                                swap_inputs=swap_inputs,
1090                            ),
1091                            unary_postop_dict[unary_op],
1092                        ),
1093                    )
1094                }
1095            )
1096        for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items():
1097            _register_quantized_linear_binary_lowering(
1098                patterns,
1099                0,  # pass_number
1100                qlinear_binary_op,  # computation_op
1101                binary_unary_attr,  # binary_unary_attr
1102            )
1103
1104        # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
1105        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
1106        # 2 distinct patterns in total (2 of the cases yield identical patterns)
1107        binary_replace_float_out_patterns = {}
1108        for swap_binary_inputs in swap_binary_inputs_list:
1109            binary_replace_float_out_patterns.update(
1110                {
1111                    BinaryUnaryAttr(
1112                        "sum", 1.0, "relu", [], ""
1113                    ): generate_pattern_with_unary(
1114                        generate_pattern_with_binary(
1115                            aten.add.Tensor,
1116                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1117                            KeywordArg("accum"),
1118                            dtype_convert=False,
1119                            swap_inputs=swap_binary_inputs,
1120                        ),
1121                        aten.relu.default,
1122                    ),
1123                }
1124            )
1125        for (
1126            binary_unary_attr,
1127            patterns,
1128        ) in binary_replace_float_out_patterns.items():
1129            _register_quantized_linear_binary_lowering(
1130                patterns,
1131                1,  # pass_number
1132                qlinear_binary_op,  # computation_op
1133                binary_unary_attr,
1134            )
1135        # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
1136        # Covers case (6) of int8-mixed-bf16
1137        binary_replace_float_out_patterns = {}
1138        for swap_binary_inputs in swap_binary_inputs_list:
1139            binary_replace_float_out_patterns.update(
1140                {
1141                    BinaryUnaryAttr(
1142                        "add", 1.0, "relu", [], ""
1143                    ): generate_pattern_with_unary(
1144                        generate_pattern_with_binary(
1145                            aten.add.Tensor,
1146                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1147                            KeywordArg("other"),
1148                            dtype_convert=True,
1149                            swap_inputs=swap_binary_inputs,
1150                        ),
1151                        aten.relu.default,
1152                    ),
1153                }
1154            )
1155        for (
1156            binary_unary_attr,
1157            patterns,
1158        ) in binary_replace_float_out_patterns.items():
1159            _register_quantized_linear_binary_lowering(
1160                patterns,
1161                1,  # pass_number
1162                qlinear_binary_op,  # computation_op
1163                binary_unary_attr,
1164            )
1165
1166        # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output
1167        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
1168        # 2 distinct patterns in total (2 of the cases yield identical patterns)
1169        binary_replace_float_out_patterns = {}
1170        for swap_binary_inputs in swap_binary_inputs_list:
1171            binary_replace_float_out_patterns.update(
1172                {
1173                    BinaryUnaryAttr(
1174                        "sum", 1.0, "none", [], ""
1175                    ): generate_pattern_with_binary(
1176                        aten.add.Tensor,
1177                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1178                        KeywordArg("accum"),
1179                        dtype_convert=False,
1180                        swap_inputs=swap_binary_inputs,
1181                    ),
1182                }
1183            )
1184        for (
1185            binary_unary_attr,
1186            patterns,
1187        ) in binary_replace_float_out_patterns.items():
1188            _register_quantized_linear_binary_lowering(
1189                patterns,
1190                2,  # pass_number
1191                qlinear_binary_op,  # computation_op
1192                binary_unary_attr,
1193            )
1194        # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output
1195        # Covers (6) of int8-mixed-bf16
1196        binary_replace_float_out_patterns = {}
1197        for swap_binary_inputs in swap_binary_inputs_list:
1198            binary_replace_float_out_patterns.update(
1199                {
1200                    BinaryUnaryAttr(
1201                        "add", 1.0, "none", [], ""
1202                    ): generate_pattern_with_binary(
1203                        aten.add.Tensor,
1204                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
1205                        KeywordArg("other"),
1206                        dtype_convert=True,
1207                        swap_inputs=swap_binary_inputs,
1208                    ),
1209                }
1210            )
1211        for (
1212            binary_unary_attr,
1213            patterns,
1214        ) in binary_replace_float_out_patterns.items():
1215            _register_quantized_linear_binary_lowering(
1216                patterns,
1217                2,  # pass_number
1218                qlinear_binary_op,  # computation_op
1219                binary_unary_attr,
1220            )
1221
1222
1223def _is_valid_quantized_maxpool2d_optimization_pattern():
1224    def fn(match):
1225        # Only match the pattern in which max_pool2d_with_indices returns the value
1226        # rather than the indices.
1227        get_item_node = filter_nodes(match.nodes, operator.getitem)[0]
1228        return get_item_node.args[1] == 0
1229
1230    return fn
1231
1232
1233def _register_quantized_maxpool2d_lowering(
1234    pattern,
1235    computation_op,
1236):
1237    @register_lowering_pattern(
1238        pattern,
1239        extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
1240    )
1241    def qmaxpool2d(match: Match, *args, **kwargs):
1242        x = kwargs["x"]
1243        kernel_size = kwargs["kernel_size"]
1244        stride = kwargs["stride"] if ("stride" in kwargs) else None
1245        padding = kwargs["padding"] if ("padding" in kwargs) else 0
1246        dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
1247        ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
1248
1249        if padding == 0:
1250            padding = [0, 0]
1251        if dilation == 1:
1252            dilation = [1, 1]
1253        if not stride:
1254            stride = kernel_size
1255        kernel_size = pad_listlike(kernel_size, 2)
1256        stride = pad_listlike(stride, 2)
1257        padding = pad_listlike(padding, 2)
1258        dilation = pad_listlike(dilation, 2)
1259
1260        assert len(kernel_size) == 2
1261        assert len(stride) == 2
1262        assert len(padding) == 2
1263        assert len(dilation) == 2
1264
1265        computation_args = (
1266            x,
1267            kernel_size,
1268            stride,
1269            padding,
1270            dilation,
1271            ceil_mode,
1272        )
1273        computation_args, _ = require_channels_last(computation_op, *computation_args)
1274        counters["inductor"]["qmaxpool2d_matcher_count"] += 1
1275        counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
1276        return L[computation_op](*computation_args)
1277
1278    return qmaxpool2d
1279
1280
1281def _register_quantization_maxpool2d():
1282    # Currently, default parameters are not present in the FX graph generated by Dynamo export.
1283    # So, if the user defines nn.MaxPool2d with a different set of explicitly assigned default
1284    # parameters, the generated graph has a different number of input nodes and hence
1285    # a different pattern to be matched.
1286    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
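    # For example (illustrative user code, not from this file): nn.MaxPool2d(3, stride=2)
    # exports a max_pool2d_with_indices call that carries only kernel_size and stride, while
    # nn.MaxPool2d(3, stride=2, padding=1, dilation=1, ceil_mode=True) also carries padding,
    # dilation and ceil_mode, so each argument combination below needs its own pattern.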
1287    max_pool2d_args_list = [
1288        [
1289            KeywordArg("stride"),
1290        ],
1291        [
1292            KeywordArg("stride"),
1293            KeywordArg("padding"),
1294        ],
1295        [
1296            KeywordArg("stride"),
1297            KeywordArg("padding"),
1298            KeywordArg("dilation"),
1299        ],
1300        [
1301            KeywordArg("stride"),
1302            KeywordArg("padding"),
1303            KeywordArg("dilation"),
1304            KeywordArg("ceil_mode"),
1305        ],
1306    ]
1307    for max_pool2d_args in max_pool2d_args_list:
1308        dequantize_maxpool2d_pattern = CallFunction(
1309            aten.max_pool2d_with_indices.default,
1310            get_dequantize_per_tensor_activation_pattern(),
1311            KeywordArg("kernel_size"),
1312            *max_pool2d_args,
1313        )
1314        dequantize_lowmem_maxpool2d_pattern = CallFunction(
1315            prims._low_memory_max_pool2d_with_offsets.default,
1316            get_dequantize_per_tensor_activation_pattern(),
1317            KeywordArg("kernel_size"),
1318            *max_pool2d_args,
1319            KeywordArg("offset_dtype"),
1320        )
1321        dequantize_maxpool2d_get_item_pattern = CallFunction(
1322            operator.getitem,
1323            dequantize_maxpool2d_pattern,
1324            Arg(),
1325        )
1326        dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
1327            operator.getitem,
1328            dequantize_lowmem_maxpool2d_pattern,
1329            Arg(),
1330        )
1331        _register_quantized_maxpool2d_lowering(
1332            generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
1333            quantized.max_pool2d.default,
1334        )
1335        _register_quantized_maxpool2d_lowering(
1336            generate_pattern_with_output_quant(
1337                dequantize_lowmem_maxpool2d_get_item_pattern
1338            ),
1339            quantized.max_pool2d.default,
1340        )
1341
1342
1343def _is_input_output_same_scale_zp(check_node):
1344    def fn(match):
1345        # Ensure all the inputs and the output have the same scale and zero point
1346        # Step 1: Check inputs/output zero point
1347        # Get dequant nodes at input
1348        dequant_nodes = filter_nodes(
1349            match.nodes, quantized_decomposed.dequantize_per_tensor.default
1350        )
1351        zero_points = [node.args[2] for node in dequant_nodes]
1352        # Get quant nodes at output
1353        quant_nodes = filter_nodes(
1354            match.nodes, quantized_decomposed.quantize_per_tensor.default
1355        )
1356        assert len(quant_nodes) == 1, "expect only 1 quant node at the output of the pattern"
1357        zero_points.append(quant_nodes[0].args[2])
1358        if not all(zero_point == zero_points[0] for zero_point in zero_points):
1359            return False
1360
1361        # Step 2: Check inputs/output scale
1362        scales = [node.args[1] for node in dequant_nodes]
1363        scales.append(quant_nodes[0].args[1])
1364        if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):  # type: ignore[arg-type]
1365            return False
1366
1367        return True
1368
1369    return fn
1370
1371
1372def _register_quantized_cat_lowering(
1373    pattern,
1374    computation_op,
1375):
1376    @register_lowering_pattern(
1377        pattern,
1378        extra_check=_is_input_output_same_scale_zp(aten.cat.default),
1379    )
1380    def qcat(match: Match, inputs, dim, **kwargs):
1381        # inputs has the format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
1382        uint8_inputs = [input[0] for input in inputs]
1383        counters["inductor"]["qcat_matcher_count"] += 1
1384        counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
1385        return L[computation_op](uint8_inputs, dim)
1386
1387    return qcat
1388
1389
1390_raw_dequantize_per_tensor_activation_pattern = CallFunction(
1391    quantized_decomposed.dequantize_per_tensor.default,
1392    Arg(),
1393    Arg(),
1394    Arg(),
1395    Arg(),
1396    Arg(),
1397    Arg(),
1398)
1399
1400
1401def _register_quantization_cat():
1402    dequantize_cat_pattern = CallFunction(
1403        aten.cat.default,
1404        ListOf(_raw_dequantize_per_tensor_activation_pattern),
1405        KeywordArg("dim"),
1406    )
1407    _register_quantized_cat_lowering(
1408        generate_pattern_with_output_quant(dequantize_cat_pattern),
1409        aten.cat,
1410    )
1411
1412
1413def _register_quantized_reshape_lowering(
1414    pattern,
1415    computation_op,
1416):
1417    @register_lowering_pattern(
1418        pattern,
1419        extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
1420    )
1421    def qreshape(match: Match, *args, **kwargs):
1422        qx = kwargs["x"]
1423        shape = kwargs["shape"]
1424        counters["inductor"]["qreshape_matcher_count"] += 1
1425        counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
1426        return L[computation_op](qx, shape)
1427
1428    return qreshape
1429
1430
1431def _register_quantization_reshape():
1432    dequantize_reshape_pattern = CallFunction(
1433        torch.ops.aten.reshape.default,
1434        get_dequantize_per_tensor_activation_pattern(),
1435        KeywordArg("shape"),
1436    )
1437    _register_quantized_reshape_lowering(
1438        generate_pattern_with_output_quant(dequantize_reshape_pattern),
1439        aten.reshape,
1440    )
1441
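# NOTE (editorial sketch): the eager computation matched by the qreshape lowering above,
# with a hypothetical scale/zero point shared by input and output (as required by
# _is_input_output_same_scale_zp).
def _example_qreshape_reference(x_uint8, scale, zp, shape):
    dq = quantized_decomposed.dequantize_per_tensor.default
    q = quantized_decomposed.quantize_per_tensor.default
    reshaped_fp32 = torch.reshape(dq(x_uint8, scale, zp, 0, 255, torch.uint8), shape)
    return q(reshaped_fp32, scale, zp, 0, 255, torch.uint8)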
1442
1443def _is_valid_woq_optimization_pattern():
1444    def fn(match):
1445        assert all(k in match.kwargs for k in ("x", "weight", "scales"))
1446        x = match.kwargs["x"].meta["val"]
1447        weight = match.kwargs["weight"].meta["val"]
1448        scales = match.kwargs["scales"].meta["val"]
1449        return (
1450            # For now, we only support woq mm kernels
1451            # with x.dtype=bfloat16 and w.dtype=int8
1452            x.dtype == torch.bfloat16
1453            and weight.dtype == torch.int8
1454            and scales.dtype == torch.bfloat16
1455            # _weight_int8pack_mm kernel only supports cpu now
1456            # TODO: add cuda kernel support instead of calling mul+sum
1457            and x.device.type == "cpu"
1458            and x.device == weight.device
1459            and x.device == scales.device
1460        )
1461
1462    return fn
1463
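# NOTE (editorial sketch): the eager weight-only-quantization computation the WOQ patterns
# below are derived from. The tensors are hypothetical; the check above expects a bf16
# activation and bf16 scales together with an int8 weight, all on CPU.
def _example_woq_eager_reference(
    x: torch.Tensor, int8_weight: torch.Tensor, scales: torch.Tensor
) -> torch.Tensor:
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    return torch.nn.functional.linear(x, int8_weight.to(dtype=x.dtype)) * scales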
1464
1465def _register_woq_lowering(pattern, computation_woq, computation_reshape):
1466    @register_lowering_pattern(
1467        pattern,
1468        extra_check=_is_valid_woq_optimization_pattern(),
1469    )
1470    def woq(match: Match, *args, **kwargs):
1471        x = kwargs["x"]
1472        weight = kwargs["weight"]
1473        scales = kwargs["scales"]
1474        counters["inductor"]["woq_matcher_count"] += 1
1475        counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
1476        out_features = weight.get_size()[0]
1477        origin_x_size = x.get_size()
1478        x_shape = [-1, origin_x_size[-1]]
1479        out_shape = origin_x_size[:-1] + [
1480            out_features,
1481        ]
1482        func1 = L[computation_reshape](x, x_shape)
1483        func2 = L[computation_woq](func1, weight, scales)
1484        return L[computation_reshape](func2, out_shape)
1485
1486    return woq
1487
1488
1489def _register_woq_mm_int8_pattern1():
1490    # F.linear(x, weight.to(dtype=x.dtype)) * scales
1491    # case of dispatching to mm, with x reshape
1492    _woq_pattern = CallFunction(
1493        aten.mul.Tensor,
1494        CallFunction(
1495            aten.reshape.default,
1496            CallFunction(
1497                aten.mm.default,
1498                CallFunction(aten.reshape.default, KeywordArg("x"), Arg()),
1499                CallFunction(
1500                    aten.permute.default,
1501                    CallFunction(
1502                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
1503                    ),
1504                    Arg(),
1505                ),
1506            ),
1507            Arg(),
1508        ),
1509        KeywordArg("scales"),
1510    )
1511    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
1512
1513
1514def _register_woq_mm_int8_pattern2():
1515    # F.linear(x, weight.to(dtype=x.dtype)) * scales
1516    # case of dispatching to mm, w/o x reshape
1517    _woq_pattern = CallFunction(
1518        aten.mul.Tensor,
1519        CallFunction(
1520            aten.reshape.default,
1521            CallFunction(
1522                aten.mm.default,
1523                KeywordArg("x"),
1524                CallFunction(
1525                    aten.permute.default,
1526                    CallFunction(
1527                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
1528                    ),
1529                    Arg(),
1530                ),
1531            ),
1532            Arg(),
1533        ),
1534        KeywordArg("scales"),
1535    )
1536    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
1537
1538
1539def _register_woq_mm_int8_pattern3():
1540    # F.linear(x, weight.to(dtype=x.dtype)) * scales
1541    # case of dispatching to bmm
1542    _woq_pattern = CallFunction(
1543        aten.mul.Tensor,
1544        CallFunction(
1545            aten.bmm.default,
1546            CallFunction(aten.expand.default, KeywordArg("x"), Arg()),
1547            CallFunction(
1548                aten.expand.default,
1549                CallFunction(
1550                    aten.permute.default,
1551                    CallFunction(
1552                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
1553                    ),
1554                    Arg(),
1555                ),
1556                Arg(),
1557            ),
1558        ),
1559        KeywordArg("scales"),
1560    )
1561    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)
1562
1563
1564def _register_quantization_lowerings():
1565    _register_quantization_unary_fusion()
1566    _register_quantization_binary_fusion()
1567    _register_quantization_maxpool2d()
1568    _register_quantization_cat()
1569    _register_quantization_reshape()
1570
1571
1572def _register_woq_lowerings():
1573    _register_woq_mm_int8_pattern1()
1574    _register_woq_mm_int8_pattern2()
1575    _register_woq_mm_int8_pattern3()
1576
1577
1578def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
1579    def _inner(match):
1580        assert dtype in [torch.float32, torch.bfloat16]
1581        dequant_pattern_end_node = match.output_node()
1582        if dequant_pattern_end_node.target not in [
1583            quantized_decomposed.dequantize_per_tensor.default,
1584            quantized_decomposed.dequantize_per_tensor.tensor,
1585            prims.convert_element_type.default,
1586            aten.reshape.default,
1587        ]:
1588            return False
1589
1590        if dequant_pattern_end_node.target is aten.reshape.default:
1591            dequant_node = (
1592                dequant_pattern_end_node.args[
1593                    0
1594                ]  # pattern: linear <- reshape <- dequant
1595                if dtype == torch.float32
1596                else dequant_pattern_end_node.args[0].args[
1597                    0
1598                ]  # pattern: linear <- reshape <- to_bf16 <- dequant
1599            )
1600        else:
1601            dequant_node = (
1602                dequant_pattern_end_node  # pattern: linear <- dequant
1603                if dtype == torch.float32
1604                else dequant_pattern_end_node.args[
1605                    0
1606                ]  # pattern: linear <- to_bf16 <- dequant
1607            )
1608
1609        if (
1610            dequant_node.target
1611            in [
1612                quantized_decomposed.dequantize_per_tensor.default,
1613                quantized_decomposed.dequantize_per_tensor.tensor,
1614            ]
1615            and len(list(dequant_pattern_end_node.users)) > 1
1616        ):
1617            # If the dequant pattern has more than 1 user, then do dequant promotion
1618            return True
1619        return False
1620
1621    return _inner
1622
1623
1624def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
1625    @register_freezing_graph_pattern(
1626        pattern,
1627        extra_check=_is_valid_dequant_promotion_pattern(dtype),
1628        pass_number=pass_number,
1629    )
1630    def dequant_promotion(match: Match, *args, **kwargs):
1631        # Dequant_promotion will transform
1632        # graph 1:
1633        #            quant
1634        #      + - - - | - - - +
1635        #      |    dequant    |
1636        #      |    /     \    |
1637        #      |  node1  node2 |
1638        #      + - | - - - | - +
1639        #        quant   quant
1640        # into:
1641        # graph 2:
1642        #            quant
1643        #      + - - / - \ - - +
1644        #      |dequant dequant|
1645        #      |    |      |   |
1646        #      | node1 node2   |
1647        #      + - | - - - | - +
1648        #        quant   quant
1649        # In graph 1, the dequant node is shared by node1 and node2;
1650        # as a result, neither node1 nor node2 can form an int8
1651        # fusion pattern.
1652        # After this transformation, graph 2 can hit the int8
1653        # fusion pattern dequant-node-quant for node1 and node2,
1654        # respectively.
1655        assert dtype in [torch.float32, torch.bfloat16]
1656
1657        def clone_to_new_node(graph, source_node, user_node):
1658            # Clone the source_node to a new node
1659            # Replace user_node's input from source_node to new_node
1660            assert (
1661                source_node.op == "call_function"
1662            ), "clone_to_new_node only supports nodes with op call_function"
1663            with graph.inserting_before(user_node):
1664                new_node = graph.call_function(
1665                    source_node.target,
1666                    args=source_node.args,
1667                    kwargs=source_node.kwargs,
1668                )
1669                new_node.meta = copy.copy(source_node.meta)
1670                user_node.replace_input_with(source_node, new_node)
1671            return new_node
1672
1673        # Find the start node and end node of a dequant pattern
1674        # * End node should be the match.output_node()
1675        # * Start node should be the node of dequantize_per_tensor
1676        dequant_pattern_end_node = match.output_node()
1677        assert dequant_pattern_end_node.target in [
1678            quantized_decomposed.dequantize_per_tensor.default,
1679            quantized_decomposed.dequantize_per_tensor.tensor,
1680            prims.convert_element_type.default,
1681            aten.reshape.default,
1682        ]
1683
1684        # For a dequant pattern, we expect to see the node list as:
1685        # * OPT(aten.reshape.default)
1686        # * OPT(prims.convert_element_type.default) (to_bf16)
1687        # * dequantize_per_tensor
1688        def _find_first_node_in_dequant_pattern(_node):
1689            if _node.target in [
1690                quantized_decomposed.dequantize_per_tensor.default,
1691                quantized_decomposed.dequantize_per_tensor.tensor,
1692            ]:
1693                # For a dequant pattern, we expect the start node to be a dequantize_per_tensor node
1694                return _node
1695            else:
1696                assert (
1697                    len(_node.args) >= 1
1698                ), "In a dequant pattern, each node should have at least 1 arg."
1699                return _find_first_node_in_dequant_pattern(_node.args[0])
1700
1701        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
1702            dequant_pattern_end_node
1703        )
1704
1705        assert dequant_pattern_start_node.target in [
1706            quantized_decomposed.dequantize_per_tensor.default,
1707            quantized_decomposed.dequantize_per_tensor.tensor,
1708        ]
1709
1710        # Clone the dequant pattern for each user node
1711        graph = match.graph
1712        user_node_list = list(dequant_pattern_end_node.users)
1713        for user_node in user_node_list[1:]:
1714            _source_node = dequant_pattern_end_node
1715            _user_node = user_node
1716            while _source_node != dequant_pattern_start_node.args[0]:
1717                _user_node = clone_to_new_node(graph, _source_node, _user_node)
1718                _source_node = _source_node.args[0]  # type: ignore[assignment]
1719
1720        counters["inductor"]["dequant_promotion_matcher_count"] += 1
1721        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
1722
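# NOTE (editorial sketch): what dequant promotion achieves, expressed in eager terms with
# hypothetical users `node1`/`node2` and a hypothetical scale/zero point. Before the pass,
# a single dequant feeds both users; after the pass each user re-dequantizes the shared
# int8 activation, so each branch can independently fuse dequant-op-quant.
def _example_dequant_promotion_effect(q_act, scale, zp, node1, node2):
    dq = quantized_decomposed.dequantize_per_tensor.default
    out1 = node1(dq(q_act, scale, zp, 0, 255, torch.uint8))  # private dequant for node1
    out2 = node2(dq(q_act, scale, zp, 0, 255, torch.uint8))  # private dequant for node2
    return out1, out2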
1723
1724def _is_valid_dequant_conv2d_pattern(dtype):
1725    def _inner(match):
1726        # Here we do some further checks to ensure:
1727        # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
1728        # 2. The dequant pattern has only 1 user, the conv2d node.
1729        # If these conditions aren't met, we will not
1730        # insert the weight prepack node into the matched pattern.
1731        conv_node = match.output_node()
1732        assert conv_node.target is aten.convolution.default
1733        input_meta_value = conv_node.args[0].meta.get("val")
1734        weight_meta_value = conv_node.args[1].meta.get("val")
1735        for meta_value in [input_meta_value, weight_meta_value]:
1736            if (
1737                meta_value is None
1738                or meta_value.device.type != "cpu"
1739                or meta_value.dim() != 4
1740            ):
1741                # Only support conv2d now
1742                return False
1743
1744        assert dtype in [torch.float32, torch.bfloat16]
1745
1746        if dtype == torch.float32:
1747            dequant_node = conv_node.args[0]
1748        else:
1749            convert_to_bf16 = conv_node.args[0]
1750            dequant_node = convert_to_bf16.args[0]
1751
1752        if len(list(dequant_node.users)) != 1:
1753            # Ensure the dequant pattern only has 1 user
1754            # since we will delete the dequant pattern here
1755            return False
1756        return True
1757
1758    return _inner
1759
1760
1761def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
1762    @register_freezing_graph_pattern(
1763        pattern,
1764        extra_check=_is_valid_dequant_conv2d_pattern(dtype),
1765        pass_number=pass_number,
1766    )
1767    def qconv_weight_prepack(match: Match, *args, **kwargs):
1768        """
1769        Match the pattern:
1770        int8 activation
1771          |
1772        dequant_per_tensor
1773          |
1774        Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
1775
1776        Insert weight prepack node and change the pattern to:
1777        int8 activation
1778          |
1779        onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight
1780        """
1781        assert dtype in [torch.float32, torch.bfloat16]
1782        conv_node = match.output_node()
1783        assert conv_node.target is aten.convolution.default
1784        if dtype == torch.float32:
1785            dequant_node = conv_node.args[0]
1786        else:
1787            convert_to_bf16 = conv_node.args[0]
1788            dequant_node = convert_to_bf16.args[0]  # type: ignore[union-attr]
1789        has_clone_to_channel_last_node_in_pattern = (
1790            conv_node.args[1].target is aten.clone.default  # type: ignore[union-attr]
1791        )
1792        clone_node = (
1793            conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
1794        )
1795
1796        if dtype == torch.float32:
1797            dequant_per_channel = (
1798                clone_node.args[0]  # type: ignore[union-attr]
1799                if has_clone_to_channel_last_node_in_pattern
1800                else conv_node.args[1]
1801            )
1802        else:
1803            weight_to_bf16_node = (
1804                clone_node.args[0]  # type: ignore[union-attr]
1805                if has_clone_to_channel_last_node_in_pattern
1806                else conv_node.args[1]
1807            )
1808            dequant_per_channel = weight_to_bf16_node.args[0]  # type: ignore[union-attr]
1809
1810        assert (
1811            dequant_per_channel.target  # type: ignore[union-attr]
1812            is quantized_decomposed.dequantize_per_channel.default
1813        )
1814
1815        # Activation QParams
1816        qx, x_zp, x_scale = (
1817            kwargs["x"],
1818            kwargs["x_zp"],
1819            kwargs["x_scale"],
1820        )
1821
1822        # Weight QParams
1823        qw, w_scale, w_zp = (
1824            kwargs["q_weight"],
1825            kwargs["w_scale"],
1826            kwargs["w_zp"],
1827        )
1828
1829        # Conv Params
1830        bias, stride, padding, dilation, groups = (
1831            kwargs["b"],
1832            kwargs["stride"],
1833            kwargs["padding"],
1834            kwargs["dilation"],
1835            kwargs["groups"],
1836        )
1837
1838        x_shape = qx.meta.get("tensor_meta").shape
1839        if has_free_symbols(x_shape):
1840            # For the dynamic shape case, we can't get the activation shape ahead of runtime.
1841            x_shape = None
1842        graph = match.graph
1843        with graph.inserting_before(conv_node):
1844            # Insert weight prepack node and the QConv node
1845            packed_weight_inputs = (
1846                qw,
1847                w_scale,
1848                x_scale,
1849                x_zp,
1850                stride,
1851                padding,
1852                dilation,
1853                groups,
1854                x_shape,
1855            )
1856            packed_weight_op = torch.ops.onednn.qconv_prepack
1857            prepack_weight_node = graph.call_function(
1858                packed_weight_op, args=packed_weight_inputs
1859            )
1860
1861            new_args: Tuple[Any, ...] = (
1862                qx,
1863                x_scale,
1864                x_zp,
1865                prepack_weight_node,
1866                w_scale,
1867                w_zp,
1868                bias,
1869                stride,
1870                padding,
1871                dilation,
1872                groups,
1873                1.0,  # output_scale
1874                0,  # output_zero_point
1875                dtype,  # output_dtype
1876                "none",  # attr
1877                [],  # scalars
1878                "",  # algorithm
1879            )
1880            new_conv_node = graph.call_function(
1881                torch.ops.onednn.qconv2d_pointwise.default, args=new_args
1882            )
1883            conv_node.replace_all_uses_with(new_conv_node)
1884            new_conv_node.meta.update(conv_node.meta)
1885
1886            # Erase the original conv node
1887            graph.erase_node(conv_node)
1888            # Erase the dequant pattern
1889            if dtype == torch.bfloat16:
1890                graph.erase_node(convert_to_bf16)  # type: ignore[possibly-undefined, arg-type]
1891            graph.erase_node(dequant_node)  # type: ignore[arg-type]
1892            # Erase the dequant per channel pattern
1893            if clone_node is not None:
1894                graph.erase_node(clone_node)  # type: ignore[arg-type]
1895            if dtype == torch.bfloat16:
1896                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined, arg-type]
1897            graph.erase_node(dequant_per_channel)  # type: ignore[arg-type]
1898            counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
1899            counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
1900                match.nodes
1901            )
1902
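# NOTE (editorial sketch): the rewritten computation that qconv_weight_prepack emits,
# mirroring the argument order of packed_weight_inputs and new_args above. All tensors
# and conv params are hypothetical, scale/zero point are scalars (hence the .default
# overload), and the x_shape hint is passed as None as in the dynamic shape case.
def _example_prepacked_qconv(
    qx, x_scale, x_zp, qw, w_scale, w_zp, bias,
    stride, padding, dilation, groups, out_dtype,
):
    packed_weight = torch.ops.onednn.qconv_prepack(
        qw, w_scale, x_scale, x_zp, stride, padding, dilation, groups, None
    )
    return torch.ops.onednn.qconv2d_pointwise.default(
        qx, x_scale, x_zp, packed_weight, w_scale, w_zp, bias,
        stride, padding, dilation, groups,
        1.0, 0, out_dtype, "none", [], "",
    )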
1903
1904def _generate_dequant_convolution_node_pattern(
1905    _dequant_per_channel_pattern, dtype=torch.float32
1906):
1907    assert dtype in [torch.float32, torch.bfloat16]
1908    dequant_convolution_node_pattern = CallFunction(
1909        aten.convolution.default,
1910        _may_generate_pattern_with_dtype_convert(
1911            get_dequantize_per_tensor_activation_pattern(),
1912            KeywordArg("autocast_act_dtype"),
1913            dtype == torch.bfloat16,
1914        ),
1915        _dequant_per_channel_pattern,
1916        KeywordArg("b"),
1917        KeywordArg("stride"),
1918        KeywordArg("padding"),
1919        KeywordArg("dilation"),
1920        KeywordArg("is_transposed"),
1921        KeywordArg("out_padding"),
1922        KeywordArg("groups"),
1923    )
1924    return dequant_convolution_node_pattern
1925
1926
1927def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
1928    assert dtype in [torch.float32, torch.bfloat16]
1929    return (
1930        _generate_dequant_convolution_node_pattern(
1931            dequantize_per_channel_weight_pattern
1932            if dtype == torch.float32
1933            else dequantize_per_channel_to_bf16_weight_pattern,
1934            dtype,
1935        ),
1936        # There is another pattern due to the pass of convert_conv_weights_to_channels_last
1937        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
1938        # Depending on some heuristics, it may or may not insert a to(channel_last) node
1939        # between the convolution and the dequant_per_channel node.
1940        _generate_dequant_convolution_node_pattern(
1941            dequantize_per_channel_clone_weight_pattern
1942            if dtype == torch.float32
1943            else dequantize_per_channel_to_bf16_clone_weight_pattern,
1944            dtype,
1945        ),
1946    )
1947
1948
1949def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
1950    output_reshape_node = None
1951    if input_dim_exceeds_two:
1952        if input_contiguous:
1953            output_reshape_node = match.output_node()
1954            assert output_reshape_node.target is aten.reshape.default
1955            linear_node = output_reshape_node.args[0]
1956        else:
1957            linear_nodes = filter_nodes(match.nodes, aten.bmm.default)
1958            assert len(linear_nodes) == 1
1959            linear_node = linear_nodes[0]
1960    else:
1961        linear_node = match.output_node()
1962
1963    assert linear_node.target in (
1964        aten.addmm.default,
1965        aten.mm.default,
1966        aten.bmm.default,
1967    )
1968    return linear_node, output_reshape_node
1969
1970
1971def _get_linear_dq_node(
1972    linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
1973):
1974    act_reshape_node = None
1975    activation_to_bf16_node = None
1976    act_expand_node = None
1977    if input_dim_exceeds_two:
1978        if input_contiguous:
1979            act_reshape_node = linear_node.args[input_index]
1980            assert act_reshape_node.target is aten.reshape.default
1981            if dtype == torch.float32:
1982                # pattern: linear -> reshape -> dequant
1983                dequant_node = act_reshape_node.args[0]
1984            else:
1985                # pattern: linear -> reshape -> to_bf16 -> dequant
1986                activation_to_bf16_node = act_reshape_node.args[0]
1987                dequant_node = activation_to_bf16_node.args[0]
1988        else:
1989            # bmm pattern decomposed from linear when the input dim exceeds 2 and the input is not contiguous
1990            act_expand_node = linear_node.args[input_index]
1991            assert act_expand_node.target is aten.expand.default
1992            if dtype == torch.float32:
1993                dequant_node = act_expand_node.args[0]
1994            else:
1995                activation_to_bf16_node = act_expand_node.args[0]
1996                dequant_node = activation_to_bf16_node.args[0]
1997    else:
1998        if dtype == torch.float32:
1999            # pattern: linear -> dequant
2000            dequant_node = linear_node.args[input_index]
2001        else:
2002            # pattern: linear -> to_bf16 -> dequant
2003            activation_to_bf16_node = linear_node.args[input_index]
2004            dequant_node = activation_to_bf16_node.args[0]
2005    return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
2006
2007
2008def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
2009    def _inner(match):
2010        # Check that the dequant pattern has only 1 user.
2011        (
2012            linear_node,
2013            _,
2014        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
2015
2016        input_index = 1 if linear_node.target is aten.addmm.default else 0
2017        assert dtype in [torch.float32, torch.bfloat16]
2018        (
2019            dequant_node,
2020            _,
2021            _,
2022            _,
2023        ) = _get_linear_dq_node(
2024            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
2025        )
2026
2027        assert dequant_node.target in [
2028            quantized_decomposed.dequantize_per_tensor.default,
2029            quantized_decomposed.dequantize_per_tensor.tensor,
2030        ]
2031
2032        if len(list(dequant_node.users)) != 1:
2033            # Ensure the dequant pattern only has 1 user
2034            # since we will delete the dequant pattern here
2035            return False
2036
2037        # Extra check for bmm pattern
2038        if input_dim_exceeds_two and not input_contiguous:
2039            # Check for act
2040            # The act expand size should be exactly the same as the act size
2041            act_expand_size = match.kwargs["act_expand_size"]
2042            act_node = match.kwargs["x"]
2043            if not (
2044                hasattr(act_node, "meta")
2045                and isinstance(act_node.meta.get("val", None), torch.Tensor)
2046                and (act_node.meta["val"].size() == torch.Size(act_expand_size))
2047            ):
2048                return False
2049
2050            # Check for wgt
2051            # wgt permute dims should be [1, 0]
2052            wgt_permute_dims = match.kwargs["permute_axes"]
2053            if wgt_permute_dims != [1, 0]:
2054                return False
2055
2056            # Check the wgt size items below:
2057            # wgt before expand should have dim 2
2058            # Expand size should have dim 3
2059            # Expand size[0] should be the same as act size[0]
2060            # Expand size[1] should be the same as wgt size[1]
2061            # Expand size[2] should be the same as wgt size[0]
2062            qweight_node = match.kwargs["q_weight"]
2063            wgt_expand_size = match.kwargs["wgt_expand_size"]
2064            if not (
2065                hasattr(qweight_node, "meta")
2066                and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
2067                and len(qweight_node.meta["val"].size()) == 2
2068                and len(wgt_expand_size) == 3
2069                and wgt_expand_size[0] == act_node.meta["val"].size()[0]
2070                and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
2071                and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
2072            ):
2073                return False
2074
2075        return True
2076
2077    return _inner
2078
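# NOTE (editorial sketch): the shape relations checked above for the bmm decomposition,
# with hypothetical sizes. For an activation of shape (B, M, K) and a 2D quantized weight
# of shape (N, K), the decomposed graph expands the activation to (B, M, K) and the
# permuted weight to (B, K, N).
def _example_bmm_expand_sizes(batch: int, m: int, k: int, n: int):
    act_expand_size = [batch, m, k]  # must equal act.size()
    wgt_expand_size = [batch, k, n]  # [act.size(0), wgt.size(1), wgt.size(0)]
    return act_expand_size, wgt_expand_size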
2079
2080def _register_qlinear_weight_prepack_pass(
2081    pattern,
2082    pass_number,
2083    dtype=torch.float32,
2084    input_dim_exceeds_two=False,
2085    input_contiguous=True,
2086):
2087    @register_freezing_graph_pattern(
2088        pattern,
2089        extra_check=_is_valid_dequant_linear_pattern(
2090            dtype, input_dim_exceeds_two, input_contiguous
2091        ),
2092        pass_number=pass_number,
2093    )
2094    def qlinear_weight_prepack(match: Match, *args, **kwargs):
2095        """
2096        Match the pattern:
2097        int8 activation
2098          |
2099        dequant_per_tensor
2100          |
2101        mm/addmm <- t <- dequant_per_channel <- int8_weight
2102
2103        Insert weight prepack node and change the pattern to:
2104        int8 activation
2105          |
2106        onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
2107        """
2108        assert dtype in [torch.float32, torch.bfloat16]
2109        (
2110            linear_node,
2111            output_reshape_node,
2112        ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
2113        input_index = 1 if linear_node.target is aten.addmm.default else 0
2114        weight_index = input_index + 1
2115
2116        (
2117            dequant_node,
2118            act_reshape_node,
2119            activation_to_bf16_node,
2120            act_expand_node,
2121        ) = _get_linear_dq_node(
2122            linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
2123        )
2124
2125        if input_dim_exceeds_two and not input_contiguous:
2126            wgt_expand_node = linear_node.args[weight_index]
2127            assert wgt_expand_node.target is aten.expand.default
2128            t_node = wgt_expand_node.args[0]
2129        else:
2130            t_node = linear_node.args[weight_index]
2131
2132        if dtype == torch.float32:
2133            dequant_per_channel = t_node.args[0]
2134        else:
2135            weight_to_bf16_node = t_node.args[0]
2136            dequant_per_channel = weight_to_bf16_node.args[0]
2137        assert (
2138            dequant_per_channel.target
2139            is quantized_decomposed.dequantize_per_channel.default
2140        )
2141
2142        # Activation QParams
2143        qx, x_zp, x_scale = (
2144            kwargs["x"],
2145            kwargs["x_zp"],
2146            kwargs["x_scale"],
2147        )
2148
2149        # Weight QParams
2150        qw, w_scale, w_zp = (
2151            kwargs["q_weight"],
2152            kwargs["w_scale"],
2153            kwargs["w_zp"],
2154        )
2155
2156        # Params
2157        bias = kwargs["b"] if "b" in kwargs else None
2158
2159        x_shape = qx.meta.get("tensor_meta").shape
2160        if has_free_symbols(x_shape):
2161            # For the dynamic shape case, we can't get the activation shape ahead of runtime.
2162            x_shape = None
2163        graph = match.graph
2164        with graph.inserting_before(linear_node):
2165            # Insert weight prepack node and the qlinear node
2166            packed_weight_inputs = (
2167                qw,
2168                x_shape,
2169            )
2170            packed_weight_op = torch.ops.onednn.qlinear_prepack
2171            prepack_weight_node = graph.call_function(
2172                packed_weight_op, args=packed_weight_inputs
2173            )
2174
2175            new_args: Tuple[Any, ...] = (
2176                qx,
2177                x_scale,
2178                x_zp,
2179                prepack_weight_node,
2180                w_scale,
2181                w_zp,
2182                bias,
2183                1.0,  # output_scale
2184                0,  # output_zero_point
2185                dtype,  # output_dtype
2186                "none",  # post op name
2187                [],  # post op args
2188                "",  # post op algorithm
2189            )
2190            Node = torch.fx.node.Node
2191            if isinstance(x_scale, Node) and isinstance(x_zp, Node):
2192                new_linear_node = graph.call_function(
2193                    torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
2194                )
2195            else:
2196                new_linear_node = graph.call_function(
2197                    torch.ops.onednn.qlinear_pointwise.default, args=new_args
2198                )
2199            if input_dim_exceeds_two:
2200                if input_contiguous:
2201                    output_reshape_node.replace_all_uses_with(new_linear_node)
2202                    new_linear_node.meta.update(output_reshape_node.meta)
2203                else:
2204                    if bias:
2205                        output_add_node_for_bias = match.output_node()
2206                        assert output_add_node_for_bias.target is aten.add.Tensor
2207                        output_add_node_for_bias.replace_all_uses_with(new_linear_node)
2208                        new_linear_node.meta.update(output_add_node_for_bias.meta)
2209                    else:
2210                        linear_node.replace_all_uses_with(new_linear_node)
2211                        new_linear_node.meta.update(linear_node.meta)
2212            else:
2213                linear_node.replace_all_uses_with(new_linear_node)
2214                new_linear_node.meta.update(linear_node.meta)
2215
2216            # Erase the original linear node
2217            if input_dim_exceeds_two:
2218                if input_contiguous:
2219                    graph.erase_node(output_reshape_node)
2220                elif not input_contiguous and bias:
2221                    graph.erase_node(output_add_node_for_bias)  # type: ignore[possibly-undefined]
2222            graph.erase_node(linear_node)
2223            if input_dim_exceeds_two:
2224                if input_contiguous:
2225                    graph.erase_node(act_reshape_node)
2226                else:
2227                    graph.erase_node(act_expand_node)
2228                    graph.erase_node(wgt_expand_node)  # type: ignore[possibly-undefined]
2229            if dtype == torch.bfloat16:
2230                graph.erase_node(activation_to_bf16_node)
2231            # Erase the dequant pattern
2232            graph.erase_node(dequant_node)
2233            # Erase the dequant per channel pattern
2234            graph.erase_node(t_node)
2235            if dtype == torch.bfloat16:
2236                graph.erase_node(weight_to_bf16_node)  # type: ignore[possibly-undefined]
2237            graph.erase_node(dequant_per_channel)
2238
2239            counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
2240            counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
2241                match.nodes
2242            )
2243
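# NOTE (editorial sketch): the rewritten computation that qlinear_weight_prepack emits,
# mirroring packed_weight_inputs and new_args above. All tensors are hypothetical,
# scale/zero point are scalars (hence the .default overload), and the x_shape hint is
# passed as None as in the dynamic shape case.
def _example_prepacked_qlinear(qx, x_scale, x_zp, qw, w_scale, w_zp, bias, out_dtype):
    packed_weight = torch.ops.onednn.qlinear_prepack(qw, None)
    return torch.ops.onednn.qlinear_pointwise.default(
        qx, x_scale, x_zp, packed_weight, w_scale, w_zp, bias,
        1.0, 0, out_dtype, "none", [], "",
    )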
2244
2245def _generate_dequant_linear_node_pattern(
2246    _dequant_per_channel_pattern,
2247    dtype=torch.float32,
2248    input_dim_exceeds_two=False,
2249    is_tensor_overload=False,
2250):
2251    assert dtype in [torch.float32, torch.bfloat16]
2252    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
2253    dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
2254        CallFunction(
2255            aten.addmm.default,
2256            KeywordArg("b"),
2257            _may_generate_pattern_with_reshape(
2258                _may_generate_pattern_with_dtype_convert(
2259                    get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
2260                    KeywordArg("autocast_act_dtype"),
2261                    dtype == torch.bfloat16,
2262                ),
2263                KeywordArg("act_reshape_size"),
2264                input_dim_exceeds_two,
2265            ),
2266            t_pattern,
2267        ),
2268        KeywordArg("output_reshape_size"),
2269        input_dim_exceeds_two,
2270    )
2271    dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
2272        CallFunction(
2273            aten.mm.default,
2274            _may_generate_pattern_with_reshape(
2275                _may_generate_pattern_with_dtype_convert(
2276                    get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
2277                    KeywordArg("autocast_act_dtype"),
2278                    dtype == torch.bfloat16,
2279                ),
2280                KeywordArg("act_reshape_size"),
2281                input_dim_exceeds_two,
2282            ),
2283            t_pattern,
2284        ),
2285        KeywordArg("output_reshape_size"),
2286        input_dim_exceeds_two,
2287    )
2288    return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
2289
2290
2291def _generate_dequant_bmm_node_pattern(
2292    _dequant_per_channel_pattern,
2293    dtype=torch.float32,
2294    with_bias=False,
2295    is_tensor_overload=False,
2296):
2297    # Used when the activation of linear has dim exceeding 2 and is not contiguous
2298    t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
2299
2300    assert dtype in [torch.float32, torch.bfloat16]
2301    dequant_bmm_pattern = CallFunction(
2302        aten.bmm.default,
2303        CallFunction(
2304            aten.expand.default,
2305            _may_generate_pattern_with_dtype_convert(
2306                get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
2307                KeywordArg("autocast_act_dtype"),
2308                dtype == torch.bfloat16,
2309            ),
2310            KeywordArg("act_expand_size"),
2311        ),
2312        CallFunction(
2313            aten.expand.default,
2314            t_pattern,
2315            KeywordArg("wgt_expand_size"),
2316        ),
2317    )
2318
2319    def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias):
2320        if _with_bias:
2321            return CallFunction(
2322                aten.add.Tensor,
2323                _dequant_bmm_pattern,
2324                KeywordArg("b"),
2325            )
2326        else:
2327            return _dequant_bmm_pattern
2328
2329    return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias)
2330
2331
2332def _generate_qlinear_weight_prepack_patterns(
2333    dtype=torch.float32,
2334    input_dim_exceeds_two=False,
2335    input_contiguous=True,
2336    with_bias=False,
2337    is_tensor_overload=False,
2338):
2339    if input_dim_exceeds_two and not input_contiguous:
2340        return _generate_dequant_bmm_node_pattern(
2341            dequantize_per_channel_weight_pattern,
2342            dtype,
2343            with_bias,
2344            is_tensor_overload,
2345        )
2346    else:
2347        return _generate_dequant_linear_node_pattern(
2348            dequantize_per_channel_weight_pattern,
2349            dtype,
2350            input_dim_exceeds_two,
2351            is_tensor_overload,
2352        )
2353
2354
2355def _register_dequant_promotion():
2356    dequant_pattern_cases = itertools.product(
2357        [torch.float32, torch.bfloat16], [True, False], [True, False]
2358    )
2359    for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
2360        # 4 dequantization patterns will be matched based on the dtype and input dimension size.
2361        # Case 1: int8-mixed-fp32, input dim size is 2
2362        # Case 2: int8-mixed-fp32, input dim size exceeds 2
2363        # Case 3: int8-mixed-bf16, input dim size is 2
2364        # Case 4: int8-mixed-bf16, input dim size exceeds 2
2365        #           quant
2366        #   + - - - - | - - - - +
2367        #   |      dequant      |
2368        #   |         |         |
2369        #   |    OPT(to_bf16)   |
2370        #   |         |         |
2371        #   |    OPT(reshape)   |
2372        #   |      /     \      |
2373        #   |    node1  node2   |
2374        #   + - - | - - - | - - +
2375        #  OPT(reshape) OPT(reshape)
2376        #   + - - | - - - | - - +
2377        #  OPT(to_fp32) OPT(to_fp32)
2378        #   + - - | - - - | - - +
2379        #       quant   quant
2380        _register_dequant_promotion_pass(
2381            _may_generate_pattern_with_reshape(
2382                _may_generate_pattern_with_dtype_convert(
2383                    get_dequantize_per_tensor_activation_pattern(
2384                        is_tensor_overload=is_tensor_overload
2385                    ),
2386                    KeywordArg("autocast_act_dtype"),
2387                    dtype == torch.bfloat16,
2388                ),
2389                KeywordArg("act_reshape_size"),
2390                with_reshape=input_dim_exceeds_two,
2391            ),
2392            pass_number=0,
2393            dtype=dtype,
2394        )  # pass_number=0 to run before weight prepack
2395
2396
2397def _register_qconv_weight_prepack():
2398    for dtype in [torch.float32, torch.bfloat16]:
2399        weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
2400        for weight_prepack_pattern in weight_prepack_patterns:
2401            # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
2402            _register_qconv_weight_prepack_pass(
2403                weight_prepack_pattern, pass_number=1, dtype=dtype
2404            )
2405
2406
2407def _register_qlinear_weight_prepack():
2408    # 6 linear-related patterns will be matched based on the dtype, input dimension size and input contiguity.
2409    # Then the pattern is converted into a QLinear node with int8-mixed-fp32/bf16.
2410    # Case 1: int8-mixed-fp32, input dim size is 2
2411    # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
2412    # Case 3: int8-mixed-bf16, input dim size is 2
2413    # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous
2414
2415    #   + - - - - | - - - - - - | - - - - - +
2416    #   |    dq_per_tensor  dq_per_channel  |
2417    #   |         |              |          |
2418    #   |    OPT(to_bf16)    OPT(to_bf16)   |
2419    #   |         |              |          |
2420    #   |     OPT(reshape)   permute        |
2421    #   |            \        /             |
2422    #   |             addmm/mm              |
2423    #   |                |                  |
2424    #   |           OPT(reshape)            |
2425
2426    # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
2427    # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous
2428
2429    #   + - - - - | - - - - - - | - - - - - +
2430    #   |    dq_per_tensor  dq_per_channel  |
2431    #   |         |              |          |
2432    #   |    OPT(to_bf16)    OPT(to_bf16)   |
2433    #   |         |              |          |
2434    #   |       expand       permute        |
2435    #   |          \             |          |
2436    #   |                    expand         |
2437    #   |                    /              |
2438    #   |               bmm                 |
2439    #   |                |                  |
2440    #   |            OPT(add)               |
2441
2442    linear_weight_prepack_cases = itertools.product(
2443        [torch.float32, torch.bfloat16], [True, False], [True, False]
2444    )
2445
2446    # Step 1: register patterns from mm and addmm
2447    for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
2448        weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
2449            dtype,
2450            input_dim_exceeds_two,
2451            is_tensor_overload=is_tensor_overload,
2452        )
2453        for weight_prepack_pattern in weight_prepack_patterns:
2454            # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
2455            _register_qlinear_weight_prepack_pass(
2456                weight_prepack_pattern,
2457                pass_number=1,
2458                dtype=dtype,
2459                input_dim_exceeds_two=input_dim_exceeds_two,
2460            )
2461
2462    # Step 2: register patterns from bmm
2463    # Linear might be decomposed into bmm when the input dim exceeds 2 and the input is not contiguous
2464    # refer to:
2465    # https://github.com/pytorch/pytorch/blob/
2466    # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
2467    # in this case, we can convert it back to qlinear
2468    for dtype, with_bias, is_tensor_overload in itertools.product(
2469        [torch.float32, torch.bfloat16], [True, False], [True, False]
2470    ):
2471        bmm_pattern = _generate_qlinear_weight_prepack_patterns(
2472            dtype=dtype,
2473            input_dim_exceeds_two=True,
2474            input_contiguous=False,
2475            with_bias=with_bias,
2476            is_tensor_overload=is_tensor_overload,
2477        )
2478        _register_qlinear_weight_prepack_pass(
2479            bmm_pattern,
2480            pass_number=1
2481            if with_bias
2482        else 2,  # if with_bias, there is an output add, so we should try to match it first
2483            dtype=dtype,
2484            input_dim_exceeds_two=True,
2485            input_contiguous=False,
2486        )
2487
2488
2489@functools.lru_cache(None)
2490def _register_quantization_weight_pack_pass():
2491    # Step 1: Dequant promotion for int8-mixed-fp32/bf16
2492    _register_dequant_promotion()
2493
2494    # Step 2: QConv weight prepack
2495    _register_qconv_weight_prepack()
2496
2497    # Step 3: QLinear weight prepack
2498    _register_qlinear_weight_prepack()
2499
2500
2501def quant_lift_up(graph_module: torch.fx.GraphModule):
2502    """
2503    Lift up the quant node before view-like nodes. It can benefit the performance
2504    of attention-like blocks. For example, we have the pattern:
2505
2506             DQ
2507    DQ       LINEAR
2508    LINEAR   VIEW
2509    VIEW     PERMUTE
2510    PERMUTE  TRANSPOSE
2511    Q        Q
2512    DQ       DQ
2513       Matmul
2514        DIV
2515        ADD
2516      SOFTMAX
2517
2518    We want to lift the quant nodes up from before Matmul to above the view-like
2519    nodes, so that they directly follow the output of the Linear node.
2520
2521             DQ
2522    DQ       LINEAR
2523    LINEAR   Q
2524    Q        VIEW
2525    VIEW     PERMUTE
2526    PERMUTE  TRANSPOSE
2527    DQ       DQ
2528       Matmul
2529        DIV
2530        ADD
2531      SOFTMAX
2532
2533    It produces a DQ->LINEAR->Q pattern which can be fused by the backend.
2534    """
2535
2536    def is_view_op(node):
2537        return node.op == "call_function" and node.target in _VIEW_OPS
2538
2539    for node in graph_module.graph.nodes:
2540        # <TODO> Leslie: Here we verify that the quant node has exactly
2541        # one input FX node, with constant scalar values for scale and zero point.
2542        # For the case where the quant node has more than one input FX node,
2543        # extend the implementation to lift up all the connected nodes
2544        # before the view nodes to keep the topological order.
2545        if (
2546            node.op == "call_function"
2547            and node.target in _PER_TENSOR_QUANTIZE_OPS
2548            and len(node.all_input_nodes) == 1
2549            and is_view_op(node.all_input_nodes[0])
2550        ):
2551            quant_node = node
2552            input_node_of_quant = quant_node.args[0]
2553
2554            # Check that each node along the lift-up path has only 1 user node
2555            # Propagate through the view-like nodes to find where to insert the new quant node
2556            could_lift_up = True
2557            current_node = quant_node
2558            input_node = current_node.args[0]
2559            while is_view_op(input_node):
2560                if len(input_node.users) != 1:
2561                    could_lift_up = False
2562                    break
2563                current_node = input_node
2564                input_node = current_node.args[0]
2565
2566            # Further check the input node of the first view node has only 1 user node
2567            if could_lift_up and len(input_node.users) == 1:
2568                # Replace dequant's input from quant to quant's input
2569                quant_node.replace_all_uses_with(input_node_of_quant)
2570                # Insert the new quant node
2571                with graph_module.graph.inserting_before(current_node):
2572                    new_quant_node = graph_module.graph.node_copy(quant_node)
2573                    input_node.replace_all_uses_with(new_quant_node)
2574
2575                    # Update inputs of new_quant_node
2576                    def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
2577                        if n == input_node_of_quant:
2578                            return input_node
2579                        else:
2580                            return n
2581
2582                    new_args = map_arg(new_quant_node.args, maybe_replace_node)
2583                    new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
2584                    new_quant_node.args = new_args  # type: ignore[assignment]
2585                    new_quant_node.kwargs = new_kwargs  # type: ignore[assignment]
2586                    graph_module.graph.erase_node(quant_node)
2587
2588    graph_module.graph.lint()
2589    graph_module.recompile()
2590
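# NOTE (editorial sketch): quant_lift_up rewrites the FX graph in place. A typical
# (hypothetical) call site applies it to a quantized GraphModule before the quantization
# fusion passes above are run.
def _example_apply_quant_lift_up(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    quant_lift_up(gm)  # hoists per-tensor quant nodes above transpose/permute/view chains
    return gm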