xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/quantizer/x86_inductor_quantizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import functools
4import itertools
5import operator
6import warnings
7from dataclasses import dataclass
8from typing import (
9    Any,
10    Callable,
11    Dict,
12    List,
13    Optional,
14    Sequence,
15    Set,
16    Tuple,
17    TYPE_CHECKING,
18    Union,
19)
20from typing_extensions import TypeAlias
21
22import torch
23import torch.nn.functional as F
24from torch.ao.quantization.fake_quantize import (
25    FakeQuantize,
26    FusedMovingAvgObsFakeQuantize,
27)
28from torch.ao.quantization.observer import (
29    HistogramObserver,
30    MovingAverageMinMaxObserver,
31    MovingAveragePerChannelMinMaxObserver,
32    PerChannelMinMaxObserver,
33    PlaceholderObserver,
34)
35from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
36from torch.ao.quantization.quantizer.quantizer import (
37    QuantizationAnnotation,
38    QuantizationSpec,
39    Quantizer,
40    SharedQuantizationSpec,
41)
42from torch.ao.quantization.quantizer.utils import _get_module_name_filter
43from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
44    get_bias_qspec,
45    get_input_act_qspec,
46    get_output_act_qspec,
47    get_weight_qspec,
48    OperatorConfig,
49    OperatorPatternType,
50    QuantizationConfig,
51)
52from torch.fx import Node
53from torch.fx.passes.utils.source_matcher_utils import (
54    get_source_partitions,
55    SourcePartition,
56)
57
58
59FilterFn: TypeAlias = Callable[[List[Node]], bool]
60
61
62if TYPE_CHECKING:
63    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
64
65__all__ = [
66    "X86InductorQuantizer",
67    "get_default_x86_inductor_quantization_config",
68]
69
70
71@dataclass
72class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
73    # _is_output_of_quantized_pattern:
74    #  * Node as output node of a fusion pattern.
75    #  * The fusion pattern supports int8 data type.
76    #  * The fusion pattern has inputs annotated to insert observer.
77    #  * The quantization_config is not `None`.
78    _is_output_of_quantized_pattern: bool = False
79
80
81# Operators that:
82# 1. Operators are optimized to run with int8 when int8 input provided.
83# 2. Operators do not support int8 input and produce fp32 output.
84int8_in_int8_out_ops: Set = {
85    torch.ops.aten.max_pool2d.default,
86    torch.ops.aten.cat.default,
87    torch.ops.aten.avg_pool2d.default,
88    torch.ops.aten.adaptive_avg_pool2d.default,
89    torch.ops.aten.flatten.using_ints,
90}
91
92# Operators that support the int8 data type for quantization config propagation.
93# A superset of int8_in_int8_out_ops incorporating additional operators.
94propagation_quantizable_ops = int8_in_int8_out_ops
95
96# Operators support the int8 data type
97# and recipe is configured by default in X86InductorQuantizer.
98default_quantizable_ops = propagation_quantizable_ops | {
99    torch.ops.aten.conv2d.default,
100    torch.ops.aten.linear.default,
101}
102
103# A superset of default_quantizable_ops includes operators support the int8 data type
104# but not enabled by default recipe of X86InductorQuantizer.
105quantizable_ops = default_quantizable_ops | {
106    torch.ops.aten.matmul.default,
107}
108
109QUANT_ANNOTATION_KEY = "quantization_annotation"
110
111
112def _skip_annotate(nodes: List[Node], filter_fn: Optional[FilterFn] = None) -> bool:
113    """Determine whether to skip annotation for a list of nodes."""
114
115    # 1) Skip annotate if any node is already annotated
116    if _is_any_annotated(nodes):
117        return True
118
119    # 2) Proceed annotate if a) a filter function is provided
120    # and b) the given nodes list passes the filter function check.
121    if filter_fn and filter_fn(nodes):
122        return False
123
124    return True
125
126
127def _create_module_name_filter(module_name: str) -> FilterFn:
128    """Create a filter function for a given module name.
129
130    The filter function takes a list of nodes (as determined by the annotate function)
131    and return True if *all* nodes come from the specified module name, False otherwise.
132
133    For example:
134        linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1`
135        relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1`
136
137    >> module_name_filter = _create_module_name_filter_inner("sub")
138    >> print(module_name_filter([relu, linear_1]))
139    # True  # These two nodes are determined by `_annotate_linear_unary` function and from "sub".
140    """
141
142    filter_fn = _get_module_name_filter(module_name)
143
144    def check_all_nodes_from_module(nodes: List[Node]) -> bool:
145        all_nodes_from_module_name: bool = all(filter_fn(n) for n in nodes)
146        return all_nodes_from_module_name
147
148    return check_all_nodes_from_module
149
150
151def _create_operator_type_filter(
152    operator_type: Callable,
153) -> FilterFn:
154    """Create a filter function for a given operator type.
155
156    The filter function takes a list of nodes and returns True if it contains
157    exactly one node with the specified operator type, False otherwise.
158
159    For example:
160        linear_1: "f32[3, 10]" = torch.ops.aten.linear.default(...) # comes from a module with name `sub.linear1`
161        relu: "f32[3, 10]" = torch.ops.aten.relu.default(linear_1); # comes from a module with name `sub.relu1`
162
163    >> operator_type_filter = _create_operator_type_filter(torch.ops.aten.linear.default)
164    >> print(operator_type_filter([relu, linear_1]))
165    # True  # These two nodes are determined by `_annotate_linear_unary` function and the second node is `linear`.
166    """
167
168    def operator_type_filter(nodes: List[Node]):
169        num_nodes_with_operator_type = sum(
170            node.target == operator_type for node in nodes
171        )
172        if num_nodes_with_operator_type > 1:
173            raise NotImplementedError(
174                f"Several nodes within a single pattern are {operator_type}."
175            )
176        return num_nodes_with_operator_type == 1
177
178    return operator_type_filter
179
180
181def _global_config_filter(nodes: List[Node]) -> bool:
182    """Filter function for global configuration.
183
184    This filter function takes a list of nodes and returns True if there is exactly one node
185    in the list that is a default quantizable operation, False otherwise.
186    """
187    num_nodes_in_default_quantizable_ops = sum(
188        node.target in default_quantizable_ops for node in nodes
189    )
190    if num_nodes_in_default_quantizable_ops > 1:
191        raise NotImplementedError(
192            "Several nodes within a single pattern are default quantizable operations."
193        )
194    return num_nodes_in_default_quantizable_ops == 1
195
196
197def _map_module_function_to_aten_operator_type():
198    module_function_to_aten_operator: Dict[Callable, torch._ops.OpOverloadPacket] = {}
199    map_list = (
200        ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default),
201        ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default),
202        ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default),
203        (
204            [
205                torch.cat,
206            ],
207            torch.ops.aten.cat.default,
208        ),
209        ([torch.nn.AvgPool2d, F.avg_pool2d], torch.ops.aten.avg_pool2d.default),
210        (
211            [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d],
212            torch.ops.aten.adaptive_avg_pool2d.default,
213        ),
214        (
215            [
216                torch.flatten,
217            ],
218            torch.ops.aten.flatten.using_ints,
219        ),
220        (
221            [
222                torch.matmul,
223            ],
224            torch.ops.aten.matmul.default,
225        ),
226    )
227    for map_item in map_list:
228        module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[call-overload]
229    return module_function_to_aten_operator
230
231
232def _mark_nodes_as_annotated(nodes: List[Node]):
233    for node in nodes:
234        if node is not None:
235            if QUANT_ANNOTATION_KEY not in node.meta:
236                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation()
237            node.meta[QUANT_ANNOTATION_KEY]._annotated = True
238
239
240def _is_node_annotated(_node):
241    """
242    return True if the node is annotated, otherwise return False
243    """
244    return (
245        QUANT_ANNOTATION_KEY in _node.meta
246        and _node.meta[QUANT_ANNOTATION_KEY]._annotated
247    )
248
249
250def _is_any_annotated(nodes: List[Node]):
251    """
252    Given a list of nodes (that represents an operator pattern),
253    check if any of the node is annotated, return True if any of the node
254    is annotated, otherwise return False.
255    """
256    return any(_is_node_annotated(node) for node in nodes)
257
258
259def _is_all_annotated(nodes: List[Node]):
260    """
261    Given a list of nodes (that represents an operator pattern),
262    return True if all of the node is annotated, otherwise return False.
263    """
264    return all(_is_node_annotated(node) for node in nodes)
265
266
267def _is_quantized_op_pt2e(node: torch.fx.Node):
268    """
269    Used for pt2e flow to check if the node is a quantized node:
270    Case1: the node has been annotated as output node of a fusion pattern.
271    Case2: the node has been annotated as single quantized node.
272    """
273    if not _is_any_annotated([node]):
274        # The node has not been annotated, directly return False
275        return False
276    quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None)
277    assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation)
278    return quantization_annotation._is_output_of_quantized_pattern
279
280
281def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
282    # TODO: Add more supported operators here.
283    supported_operators: Dict[str, List[OperatorPatternType]] = {
284        "conv2d": [
285            [torch.nn.Conv2d],
286            [F.conv2d],
287        ],
288    }
289
290    # Append Conv Optional(Add) Optioinal(ReLU)
291    conv_add_relu_options = itertools.product(
292        [torch.nn.Conv2d, F.conv2d],
293        [torch.add, operator.add, None],  # add
294        [torch.nn.ReLU, F.relu, None],  # relu
295    )
296    for conv_op, add_op, relu_op in conv_add_relu_options:
297        if add_op is None:
298            # Append Conv ReLU
299            supported_operators["conv2d"].append([conv_op, relu_op])  # type: ignore[list-item]
300        elif relu_op is None:
301            # Append Conv Add
302            supported_operators["conv2d"].append([conv_op, add_op])  # type: ignore[list-item]
303        else:
304            # Append Conv Add ReLU
305            supported_operators["conv2d"].append([conv_op, add_op, relu_op])  # type: ignore[list-item]
306
307    return copy.deepcopy(supported_operators)
308
309
310def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
311    supported_config_and_operators: List[OperatorConfig] = []
312    for quantization_config in [
313        get_default_x86_inductor_quantization_config(),
314    ]:
315        ops = _supported_quantized_operators()
316        for pattern_list in ops.values():
317            supported_config_and_operators.append(
318                OperatorConfig(quantization_config, pattern_list)
319            )
320    return copy.deepcopy(supported_config_and_operators)
321
322
323@functools.lru_cache
324def get_default_x86_inductor_quantization_config(
325    is_qat: bool = False,
326    is_dynamic: bool = False,
327):
328    extra_args: Dict[str, Any] = {"eps": 2**-12}
329    if is_qat:
330        if is_dynamic:
331            act_observer_or_fake_quant_ctr = FakeQuantize
332            dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
333                averaging_constant=1
334            )
335            extra_args["observer"] = dynamic_quant_observer
336        else:
337            act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
338    else:
339        if is_dynamic:
340            act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
341        else:
342            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
343
344    # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
345    act_quantization_spec = QuantizationSpec(
346        dtype=torch.uint8,
347        quant_min=0,
348        quant_max=255,  # reduce_range=False
349        qscheme=torch.per_tensor_affine,
350        is_dynamic=is_dynamic,
351        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
352            **extra_args
353        ),
354    )
355
356    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
357        FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver
358    )
359
360    if is_qat:
361        # Only support per channel quant for now
362        extra_args["observer"] = MovingAveragePerChannelMinMaxObserver  # type: ignore[dict-item]
363    weight_quantization_spec = QuantizationSpec(
364        dtype=torch.int8,
365        quant_min=-128,
366        quant_max=127,
367        qscheme=torch.per_channel_symmetric,
368        ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
369        is_dynamic=False,
370        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
371            **extra_args
372        ),
373    )
374    bias_quantization_spec = None  # will use placeholder observer by default
375    quantization_config = QuantizationConfig(
376        act_quantization_spec,
377        act_quantization_spec,
378        weight_quantization_spec,
379        bias_quantization_spec,
380        is_qat,
381    )
382    return quantization_config
383
384
385def _get_supported_config_and_operators() -> List[OperatorConfig]:
386    return _get_supported_x86_inductor_config_and_operators()
387
388
389def _annotate_nodes_not_quantize(nodes: Union[Node, List[Node]]) -> None:
390    """Annotate nodes to exclude them from quantization (their `quantization_config` is `None`)."""
391    if not isinstance(nodes, list):
392        nodes = [nodes]
393    for node in nodes:
394        node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
395            _annotated=True
396        )
397
398
399def _config_checker(method: Callable) -> Callable:
400    @functools.wraps(method)
401    def wrapper(
402        quantizer: "X86InductorQuantizer",
403        name: Any,
404        quantization_config: Optional["QuantizationConfig"],
405    ) -> "X86InductorQuantizer":
406        if quantizer._need_skip_config(quantization_config):
407            warnings.warn(
408                f"Skip the quantization config for {name}.",
409            )
410            return quantizer
411        return method(quantizer, name, quantization_config)
412
413    return wrapper
414
415
416@dataclass
417class _CurrentQuantizationMode:
418    r"""Configuration defining the current quantization mode for the quantizer.
419
420    All possible current quantization modes are listed below:
421    ----------------------------------------------------------------------------------------------------------
422                |                                       dynamic_state
423     qat_state  |---------------------------------------------------------------------------------------------
424                |                           None                              |    True       |  False
425    ----------------------------------------------------------------------------------------------------------
426        None    | quantizer does not receive a non-None `quantization_config` | \             | \
427        False   | quantizer will not do QAT                                   | dynamic       | static
428        True    | quantizer will do QAT                                       | QAT + dynamic | QAT + static
429    """
430
431    qat_state: Optional[bool]
432    dynamic_state: Optional[bool]
433
434
435class X86InductorQuantizer(Quantizer):
436    supported_config_and_operators = _get_supported_config_and_operators()
437    module_function_to_aten_operator_type = _map_module_function_to_aten_operator_type()
438
439    def __init__(self) -> None:
440        super().__init__()
441        self.global_config: Optional[QuantizationConfig] = None
442        self.operator_type_qconfig: Dict[
443            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
444        ] = {}
445        self.module_name_qconfig: Dict[str, Optional[QuantizationConfig]] = {}
446
447    @classmethod
448    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
449        op_configs: Set[QuantizationConfig] = {
450            spec for spec, _ in cls.supported_config_and_operators
451        }
452        return list(op_configs)
453
454    @classmethod
455    def get_supported_operator_for_quantization_config(
456        cls, quantization_config: Optional[QuantizationConfig]
457    ) -> List[OperatorPatternType]:
458        if quantization_config is None:
459            all_ops = []
460            for _, ops in cls.supported_config_and_operators:
461                all_ops.extend(ops)
462            return all_ops
463
464        for config, ops in cls.supported_config_and_operators:
465            if config == quantization_config:
466                return ops
467        return []
468
469    def _get_current_quantization_mode(self) -> _CurrentQuantizationMode:
470        """Retrieves the current quantization mode based on all configurations."""
471        qat_state = None
472        dynamic_state = None
473
474        # As we use `_need_skip_config` to skip all invalid configurations,
475        # we can safely assume that the all existing non-None configurations
476        # have the same quantization mode.
477        for qconfig in (
478            list(self.module_name_qconfig.values())
479            + list(self.operator_type_qconfig.values())
480            + [self.global_config]
481        ):
482            if qconfig is not None:
483                # Query the `is_qat` state
484                if qat_state is None:
485                    qat_state = qconfig.is_qat
486                else:
487                    assert qat_state == qconfig.is_qat, (
488                        f"All non-None quantization configs should have the same `is_qat`,"
489                        f"but got {qat_state} and {qconfig.is_qat}."
490                    )
491                # Query the `is_dynamic` state
492                input_activation_spec = qconfig.input_activation
493                if input_activation_spec is not None:
494                    if dynamic_state is None:
495                        dynamic_state = input_activation_spec.is_dynamic
496                    else:
497                        assert dynamic_state == input_activation_spec.is_dynamic, (
498                            f"All non-None `input_activation_spec` should have the same `is_dynamic`,"
499                            f"but got {dynamic_state} and {input_activation_spec.is_dynamic}."
500                        )
501        return _CurrentQuantizationMode(
502            qat_state=qat_state, dynamic_state=dynamic_state
503        )
504
505    def _need_skip_config(
506        self, quantization_config: Optional[QuantizationConfig]
507    ) -> bool:
508        """Check if the provided quantization config is valid for X86InductorQuantizer.
509
510        Mixed static/dynamic configurations or mixed QAT/non-QAT configurations are not supported.
511        To avoid such a mix, we compare the incoming configuration with current configuration status.
512        Refer the `_CurrentQuantizationMode` definition for all possible modes.
513        """
514        if quantization_config is None:
515            return False
516
517        need_skip = False
518        current_mode = self._get_current_quantization_mode()
519        if (
520            current_mode.qat_state is not None
521            and current_mode.qat_state != quantization_config.is_qat
522        ):
523            warnings.warn("Mixed QAT and Non-QAT quantization config is not supported.")
524            need_skip = True
525        if current_mode.dynamic_state is not None:
526            input_activation_spec = quantization_config.input_activation
527            if (
528                input_activation_spec is not None
529                and current_mode.dynamic_state != input_activation_spec.is_dynamic
530            ):
531                warnings.warn(
532                    "Mixed dynamic and static quantization config is not supported."
533                )
534                need_skip = True
535        return need_skip
536
537    def set_global(self, quantization_config: QuantizationConfig):
538        if self._need_skip_config(quantization_config):
539            warnings.warn("Skip the global quantization config.")
540            return self
541        self.global_config = quantization_config
542        return self
543
544    def get_global_quantization_config(self):
545        if not isinstance(self.global_config, QuantizationConfig):
546            warnings.warn(
547                "The global_config for X86InductorQuantizer is currently invalid. \
548                Please ensure that you use set_global to establish the global quantization configuration."
549            )
550        return self.global_config
551
552    @_config_checker
553    def set_function_type_qconfig(
554        self,
555        function_type: Callable,
556        quantization_config: Optional[QuantizationConfig],
557    ) -> "X86InductorQuantizer":
558        if function_type in X86InductorQuantizer.module_function_to_aten_operator_type:
559            self._set_aten_operator_qconfig(
560                X86InductorQuantizer.module_function_to_aten_operator_type[
561                    function_type
562                ],
563                quantization_config,
564            )
565        else:
566            warnings.warn(
567                f"function: Unable to customize quantization config for {function_type} by X86InductorQuantizer."
568            )
569        return self
570
571    @_config_checker
572    def set_module_type_qconfig(
573        self,
574        module_type: torch.nn.Module,
575        quantization_config: Optional[QuantizationConfig],
576    ) -> "X86InductorQuantizer":
577        if module_type in X86InductorQuantizer.module_function_to_aten_operator_type:
578            self._set_aten_operator_qconfig(
579                X86InductorQuantizer.module_function_to_aten_operator_type[module_type],
580                quantization_config,
581            )
582        else:
583            warnings.warn(
584                f"Module: Unable to customize quantization config for {module_type} by X86InductorQuantizer."
585            )
586        return self
587
588    @_config_checker
589    def set_module_name_qconfig(
590        self, module_name: str, quantization_config: Optional[QuantizationConfig]
591    ):
592        """Set quantization_config for a submodule with name: `module_name`, for example:
593        quantizer.set_module_name_qconfig("blocks.sub"), it will quantize all supported operator/operator
594        patterns in the submodule with this module name with the given `quantization_config`
595
596        The supported operators include `quantizable_ops` and `propagation_quantizable_ops`.
597        """
598        self.module_name_qconfig[module_name] = quantization_config
599        return self
600
601    def _set_aten_operator_qconfig(
602        self,
603        operator_type: torch._ops.OpOverloadPacket,
604        quantization_config: Optional[QuantizationConfig],
605    ) -> "X86InductorQuantizer":
606        if operator_type in quantizable_ops:
607            self.operator_type_qconfig[operator_type] = quantization_config
608        else:
609            warnings.warn(
610                f"operator: Unable to quantize {operator} by X86InductorQuantizer."
611            )
612        return self
613
614    def _annotate_conv_node_helper(
615        self,
616        conv_node: torch.fx.Node,
617        annotate_output: bool,
618        quantization_config: Optional[QuantizationConfig],
619    ) -> None:
620        """Helper function to annotate the conv node"""
621        if quantization_config is None:
622            _annotate_nodes_not_quantize(conv_node)
623            return
624        input_qspec_map = {}
625        input_node = conv_node.args[0]
626        assert isinstance(input_node, Node)
627        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
628        weight_node = conv_node.args[1]
629        assert isinstance(weight_node, Node)
630        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
631        bias_node = None if len(conv_node.args) == 2 else conv_node.args[2]
632        if isinstance(bias_node, Node):
633            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
634        if annotate_output:
635            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
636                input_qspec_map=input_qspec_map,
637                _annotated=True,
638                _is_output_of_quantized_pattern=True,
639            )
640        else:
641            conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
642                input_qspec_map=input_qspec_map,
643                _annotated=True,
644            )
645
646    def _annotate_linear_node_helper(
647        self,
648        linear_node: torch.fx.Node,
649        annotate_output: bool,
650        quantization_config: Optional[QuantizationConfig],
651    ) -> None:
652        """Helper function to annotate the linear node"""
653        if quantization_config is None:
654            _annotate_nodes_not_quantize(linear_node)
655            return
656        input_qspec_map = {}
657        assert linear_node.target in (torch.ops.aten.linear.default,)
658        has_bias = len(linear_node.args) == 3
659        input_index = 0
660        weight_index = 1
661        bias_index = 2
662
663        input_node = linear_node.args[input_index]
664        assert isinstance(input_node, Node)
665        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
666
667        weight_node = linear_node.args[weight_index]
668        assert isinstance(weight_node, Node)
669        input_qspec_map[weight_node] = get_weight_qspec(quantization_config)
670
671        bias_node = linear_node.args[bias_index] if has_bias else None
672        if isinstance(bias_node, Node):
673            input_qspec_map[bias_node] = get_bias_qspec(quantization_config)
674
675        if annotate_output:
676            linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
677                input_qspec_map=input_qspec_map,
678                _annotated=True,
679                _is_output_of_quantized_pattern=True,
680            )
681        else:
682            linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
683                input_qspec_map=input_qspec_map, _annotated=True
684            )
685
686    def _get_output_nodes_of_partitions(
687        self,
688        partition_list: List[SourcePartition],
689    ) -> List[torch.fx.Node]:
690        """Helper function to get the output node list from partition list"""
691        output_node_list = []
692        for partition in partition_list:
693            if len(partition.output_nodes) > 1:
694                raise ValueError("Input partition has more than one output node")
695            output_node = partition.output_nodes[0]
696            assert isinstance(output_node, Node)
697            output_node_list.append(output_node)
698        if len(output_node_list) != len(partition_list):
699            raise ValueError(
700                "length of output_node_list should equal to length of partition_list"
701            )
702        return output_node_list
703
704    def _get_input_idx_for_binary_node(
705        self,
706        conv_gemm_node: torch.fx.Node,
707        binary_node: torch.fx.Node,
708    ):
709        """Helper function to check conv_gemm and extra input node index
710        for binary node fused with conv_gemm.
711        """
712        conv_gemm_node_idx = None
713        extra_input_node_idx = None
714        if (binary_node.args[0].op == "call_function") and (  # type: ignore[union-attr]
715            binary_node.args[0] == conv_gemm_node
716        ):
717            conv_gemm_node_idx = 0
718            extra_input_node_idx = 1
719        elif (binary_node.args[1].op == "call_function") and (  # type: ignore[union-attr]
720            binary_node.args[1] == conv_gemm_node
721        ):
722            conv_gemm_node_idx = 1
723            extra_input_node_idx = 0
724        extra_input_node = binary_node.args[extra_input_node_idx]  # type: ignore[index]
725        assert isinstance(extra_input_node, Node)
726        return conv_gemm_node_idx, extra_input_node_idx
727
728    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
729        """Annotate the given model with quantization configurations.
730
731        Annotation contracts:
732        1. Annotate each node according to the user's qconfig in the following order:
733        `module_name_qconfig`, `operator_type_qconfig`, and `global_config`.
734        2. Avoid re-annotating nodes already annotated in prior stages. For example,
735        if `linear1` has been annotated by `module_name_qconfig`, it won't be annotated again
736        during the processing of the 'operator_type_qconfig' or 'global_config'.
737        3. For config is `None`, the node will be annotated with `_X86InductorQuantizationAnnotation(_annotated=True)`.
738
739        For each pair of (module_name_or_operator_type_or_global, qconfig), a filter function is created.
740        This filter function checks if the node is marked by current stage and not annotated by the previous stage.
741        """
742        for module_name, quantization_config in self.module_name_qconfig.items():
743            self._annotate_with_config(
744                model, quantization_config, _create_module_name_filter(module_name)
745            )
746
747        for operator_type, quantization_config in self.operator_type_qconfig.items():
748            self._annotate_with_config(
749                model, quantization_config, _create_operator_type_filter(operator_type)
750            )
751
752        if self.global_config:
753            self._annotate_with_config(
754                model,
755                self.global_config,
756                _global_config_filter,
757            )
758
759        # Once we've annotated the model with quantization configurations, we also need to annotate
760        # the output of quantizable operations. For example, if we annotated `maxpool2d` to quantize its inputs,
761        # we will quantize its output accordingly. This enables us to fuse the dq-operator-q into a quantized op.
762        # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/
763        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487
764
765        self._annotate_output_for_int8_in_int8_out_pattern_entry(model)
766
767        return model
768
769    def _annotate_with_config(
770        self,
771        model: torch.fx.GraphModule,
772        quantization_config: Optional[QuantizationConfig],
773        filter_fn: FilterFn,
774    ) -> None:
775        """Annotate the model with the given quantization configuration.
776
777        High-level description of quantization recipe for X86 Inductor Backend:
778        Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively.
779        Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model
780        from start to the end. If a pattern supports computation with int8 data type and inputs connected to
781        quantized patterns, annotate its inputs as quantized pattern.
782        """
783
784        # Step1: Recipe of fusion patterns like conv/linear.
785        self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn)
786        self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn)
787        self._annotate_matmul(model, quantization_config, filter_fn)
788
789        # Step2: Recipe to propagate annotation for patterns beside conv/linear.
790        # Go through all the nodes from start to end.
791        # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/
792        # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538
793
794        self._annotate_propagation_quantizable_pattern_entry(
795            model, quantization_config, filter_fn
796        )
797
798    def _annotate_qat_conv2d_fusion_pattern(
799        self,
800        model: torch.fx.GraphModule,
801        quantization_config: Optional[QuantizationConfig],
802        filter_fn: Optional[FilterFn] = None,
803    ):
804        # Annotate QAT Specific patterns
805        self._annotate_qat_conv2d_bn_binary_unary(model, quantization_config, filter_fn)
806        self._annotate_qat_conv2d_bn_binary(model, quantization_config, filter_fn)
807        self._annotate_qat_conv2d_bn_unary(model, quantization_config, filter_fn)
808        self._annotate_qat_conv2d_bn(model, quantization_config, filter_fn)
809
810    def _annotate_qat_conv2d_bn_binary_unary(
811        self,
812        gm: torch.fx.GraphModule,
813        quantization_config: Optional[QuantizationConfig],
814        filter_fn: Optional[FilterFn] = None,
815    ) -> None:
816        fused_partitions = find_sequential_partitions(
817            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU]
818        )
819        for fused_partition in fused_partitions:
820            (
821                conv_partition,
822                bn_partition,
823                binary_partition,
824                unary_partition,
825            ) = fused_partition
826
827            (
828                conv_node,
829                bn_output_node,
830                binary_node,
831                unary_node,
832            ) = self._get_output_nodes_of_partitions(
833                [conv_partition, bn_partition, binary_partition, unary_partition]
834            )
835            if len(bn_output_node.users) != 1:
836                # Conv BN pattern should only has 1 user.
837                continue
838            (
839                bn_output_node_idx,
840                extra_input_node_idx,
841            ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node)
842            if (bn_output_node_idx is None) or (extra_input_node_idx is None):
843                continue
844            if bn_output_node != binary_node.args[bn_output_node_idx]:
845                raise ValueError(f"{bn_output_node} doesn't match input of binary node")
846            extra_input_node = binary_node.args[extra_input_node_idx]
847
848            if (
849                conv_node.op != "call_function"
850                or conv_node.target != torch.ops.aten.conv2d.default
851            ):
852                continue
853
854            if _skip_annotate(
855                [unary_node, binary_node, bn_output_node, conv_node], filter_fn
856            ):
857                continue
858
859            self._annotate_conv_node_helper(conv_node, False, quantization_config)
860
861            if quantization_config is not None:
862                binary_node_input_qspec_map = {}
863                binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
864                    quantization_config
865                )
866                binary_node.meta[
867                    QUANT_ANNOTATION_KEY
868                ] = _X86InductorQuantizationAnnotation(
869                    input_qspec_map=binary_node_input_qspec_map,
870                    _annotated=True,
871                )
872                unary_node.meta[
873                    QUANT_ANNOTATION_KEY
874                ] = _X86InductorQuantizationAnnotation(
875                    # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
876                    output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
877                    _annotated=True,
878                    _is_output_of_quantized_pattern=True,
879                )
880            else:
881                _annotate_nodes_not_quantize([binary_node, unary_node])
882            nodes_to_mark_annotated = list(conv_partition.nodes)
883            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
884            nodes_to_mark_annotated.extend(list(binary_partition.nodes))
885            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
886            _mark_nodes_as_annotated(nodes_to_mark_annotated)
887
888    def _annotate_qat_conv2d_bn_binary(
889        self,
890        gm: torch.fx.GraphModule,
891        quantization_config: Optional[QuantizationConfig],
892        filter_fn: Optional[FilterFn] = None,
893    ) -> None:
894        fused_partitions = find_sequential_partitions(
895            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add]
896        )
897        for fused_partition in fused_partitions:
898            conv_partition, bn_partition, binary_partition = fused_partition
899            (
900                conv_node,
901                bn_output_node,
902                binary_node,
903            ) = self._get_output_nodes_of_partitions(
904                [conv_partition, bn_partition, binary_partition]
905            )
906            if len(bn_output_node.users) != 1:
907                # Conv BN pattern should only has 1 user.
908                continue
909            (
910                bn_output_node_idx,
911                extra_input_node_idx,
912            ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node)
913            if (bn_output_node_idx is None) or (extra_input_node_idx is None):
914                continue
915            if bn_output_node != binary_node.args[bn_output_node_idx]:
916                raise ValueError(f"{bn_output_node} doesn't match input of binary node")
917
918            extra_input_node = binary_node.args[extra_input_node_idx]
919
920            if (
921                conv_node.op != "call_function"
922                or conv_node.target != torch.ops.aten.conv2d.default
923            ):
924                continue
925
926            if _skip_annotate([binary_node, bn_output_node, conv_node], filter_fn):
927                continue
928
929            self._annotate_conv_node_helper(conv_node, False, quantization_config)
930
931            if quantization_config is not None:
932                binary_node_input_qspec_map = {}
933                binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
934                    quantization_config
935                )
936                binary_node.meta[
937                    QUANT_ANNOTATION_KEY
938                ] = _X86InductorQuantizationAnnotation(
939                    input_qspec_map=binary_node_input_qspec_map,
940                    # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
941                    output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
942                    _annotated=True,
943                    _is_output_of_quantized_pattern=True,
944                )
945            else:
946                _annotate_nodes_not_quantize(binary_node)
947            nodes_to_mark_annotated = list(conv_partition.nodes)
948            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
949            nodes_to_mark_annotated.extend(list(binary_partition.nodes))
950            _mark_nodes_as_annotated(nodes_to_mark_annotated)
951
952    def _annotate_qat_conv2d_bn_unary(
953        self,
954        gm: torch.fx.GraphModule,
955        quantization_config: Optional[QuantizationConfig],
956        filter_fn: Optional[FilterFn] = None,
957    ) -> None:
958        fused_partitions = []
959        unary_patterns = [
960            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU],
961            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
962            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish],
963            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
964            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.SiLU],
965        ]
966        for unary_pattern in unary_patterns:
967            partitions = find_sequential_partitions(gm, unary_pattern)
968            if partitions:
969                # Extend the fused_partitions if partitions is not empty
970                fused_partitions.extend(partitions)
971
972        for fused_partition in fused_partitions:
973            conv_partition, bn_partition, unary_partition = fused_partition
974            (
975                conv_node,
976                bn_output_node,
977                unary_node,
978            ) = self._get_output_nodes_of_partitions(
979                [conv_partition, bn_partition, unary_partition]
980            )
981
982            if (
983                conv_node.op != "call_function"
984                or conv_node.target != torch.ops.aten.conv2d.default
985            ):
986                continue
987
988            if _skip_annotate([unary_node, bn_output_node, conv_node], filter_fn):
989                continue
990
991            self._annotate_conv_node_helper(conv_node, False, quantization_config)
992            if quantization_config is not None:
993                unary_node.meta[
994                    QUANT_ANNOTATION_KEY
995                ] = _X86InductorQuantizationAnnotation(
996                    # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
997                    output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
998                    _annotated=True,
999                    _is_output_of_quantized_pattern=True,
1000                )
1001            else:
1002                _annotate_nodes_not_quantize(unary_node)
1003            nodes_to_mark_annotated = list(conv_partition.nodes)
1004            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
1005            nodes_to_mark_annotated.extend(list(unary_partition.nodes))
1006            _mark_nodes_as_annotated(nodes_to_mark_annotated)
1007
1008    def _annotate_qat_conv2d_bn(
1009        self,
1010        gm: torch.fx.GraphModule,
1011        quantization_config: Optional[QuantizationConfig],
1012        filter_fn: Optional[FilterFn] = None,
1013    ) -> None:
1014        fused_partitions = find_sequential_partitions(
1015            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d]
1016        )
1017        for fused_partition in fused_partitions:
1018            conv_partition, bn_partition = fused_partition
1019            conv_node, bn_output_node = self._get_output_nodes_of_partitions(
1020                [conv_partition, bn_partition]
1021            )
1022
1023            if (
1024                conv_node.op != "call_function"
1025                or conv_node.target != torch.ops.aten.conv2d.default
1026            ):
1027                continue
1028
1029            if _skip_annotate([bn_output_node, conv_node], filter_fn):
1030                continue
1031
1032            self._annotate_conv_node_helper(conv_node, False, quantization_config)
1033            if quantization_config is not None:
1034                bn_output_node.meta[
1035                    QUANT_ANNOTATION_KEY
1036                ] = _X86InductorQuantizationAnnotation(
1037                    # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
1038                    output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
1039                    _annotated=True,
1040                    _is_output_of_quantized_pattern=True,
1041                )
1042            else:
1043                _annotate_nodes_not_quantize(bn_output_node)
1044            nodes_to_mark_annotated = list(conv_partition.nodes)
1045            nodes_to_mark_annotated.extend(list(bn_partition.nodes))
1046            _mark_nodes_as_annotated(nodes_to_mark_annotated)
1047
1048    def _annotate_conv2d_fusion_pattern(
1049        self,
1050        model: torch.fx.GraphModule,
1051        quantization_config: Optional[QuantizationConfig],
1052        filter_fn: Optional[FilterFn] = None,
1053    ):
1054        if (quantization_config is None) or (quantization_config.is_qat):
1055            # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat
1056            self._annotate_qat_conv2d_fusion_pattern(
1057                model, quantization_config, filter_fn
1058            )
1059        self._annotate_conv2d_binary_unary(model, quantization_config, filter_fn)
1060        self._annotate_conv2d_binary(model, quantization_config, filter_fn)
1061        self._annotate_conv2d_unary(model, quantization_config, filter_fn)
1062        self._annotate_conv2d(model, quantization_config, filter_fn)
1063
1064    def _annotate_linear_fusion_pattern(
1065        self,
1066        model: torch.fx.GraphModule,
1067        quantization_config: Optional[QuantizationConfig],
1068        filter_fn: Optional[FilterFn] = None,
1069    ):
1070        self._annotate_linear_binary_unary(model, quantization_config, filter_fn)
1071        self._annotate_linear_unary(model, quantization_config, filter_fn)
1072        self._annotate_linear(model, quantization_config, filter_fn)
1073
1074    def _annotate_matmul(
1075        self,
1076        model: torch.fx.GraphModule,
1077        quantization_config: Optional[QuantizationConfig],
1078        filter_fn: Optional[FilterFn] = None,
1079    ):
1080        for node in model.graph.nodes:
1081            if node.target != torch.ops.aten.matmul.default:
1082                continue
1083            if _skip_annotate([node], filter_fn):
1084                continue
1085
1086            if quantization_config is None:
1087                _annotate_nodes_not_quantize(node)
1088                continue
1089
1090            input_qspec_map = {}
1091            matmul_node = node
1092            for input_node in matmul_node.args:
1093                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
1094            matmul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1095                input_qspec_map=input_qspec_map,
1096                _annotated=True,
1097                _is_output_of_quantized_pattern=True,
1098            )
1099
1100    def _annotate_conv2d_binary_unary(
1101        self,
1102        gm: torch.fx.GraphModule,
1103        quantization_config: Optional[QuantizationConfig],
1104        filter_fn: Optional[FilterFn] = None,
1105    ) -> None:
1106        # Conv2d + add + unary op
1107        fused_partitions = find_sequential_partitions(
1108            gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU]
1109        )
1110        for fused_partition in fused_partitions:
1111            conv_partition, binary_partition, unary_partition = fused_partition
1112            conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
1113                [conv_partition, binary_partition, unary_partition]
1114            )
1115            if len(conv_node.users) != 1:
1116                # Conv Node should only has 1 user node
1117                continue
1118            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
1119                conv_node, binary_node
1120            )
1121            if (conv_node_idx is None) or (extra_input_node_idx is None):
1122                continue
1123            if conv_node != binary_node.args[conv_node_idx]:
1124                raise ValueError(f"{conv_node} doesn't match input of binary node")
1125            extra_input_node = binary_node.args[extra_input_node_idx]
1126            if (
1127                conv_node.op != "call_function"
1128                or conv_node.target != torch.ops.aten.conv2d.default
1129            ):
1130                # No conv node found to be fused with add
1131                continue
1132            if _skip_annotate([unary_node, binary_node, conv_node], filter_fn):
1133                continue
1134
1135            if quantization_config is None:
1136                _annotate_nodes_not_quantize([conv_node, binary_node, unary_node])
1137                continue
1138
1139            self._annotate_conv_node_helper(conv_node, False, quantization_config)
1140            binary_node_input_qspec_map = {}
1141            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
1142                quantization_config
1143            )
1144            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1145                input_qspec_map=binary_node_input_qspec_map,
1146                _annotated=True,
1147            )
1148            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1149                _annotated=True,
1150                _is_output_of_quantized_pattern=True,
1151            )
1152
1153    def _annotate_conv2d_binary(
1154        self,
1155        gm: torch.fx.GraphModule,
1156        quantization_config: Optional[QuantizationConfig],
1157        filter_fn: Optional[FilterFn] = None,
1158    ) -> None:
1159        # Conv2d + add
1160        fused_partitions = find_sequential_partitions(
1161            gm, [torch.nn.Conv2d, operator.add]
1162        )
1163        for fused_partition in fused_partitions:
1164            conv_partition, binary_partition = fused_partition
1165            conv_node, binary_node = self._get_output_nodes_of_partitions(
1166                [conv_partition, binary_partition]
1167            )
1168            if len(conv_node.users) != 1:
1169                # Conv Node should only has 1 user node
1170                continue
1171            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
1172                conv_node, binary_node
1173            )
1174            if (conv_node_idx is None) or (extra_input_node_idx is None):
1175                continue
1176            if conv_node != binary_node.args[conv_node_idx]:
1177                raise ValueError(f"{conv_node} doesn't match input of binary node")
1178            extra_input_node = binary_node.args[extra_input_node_idx]
1179            assert isinstance(conv_node, Node)
1180            if (
1181                conv_node.op != "call_function"
1182                or conv_node.target != torch.ops.aten.conv2d.default
1183            ):
1184                # No conv node found to be fused with add
1185                continue
1186            if _skip_annotate([binary_node, conv_node], filter_fn):
1187                continue
1188
1189            if quantization_config is None:
1190                _annotate_nodes_not_quantize([conv_node, binary_node])
1191                continue
1192
1193            self._annotate_conv_node_helper(conv_node, False, quantization_config)
1194            binary_node_input_qspec_map = {}
1195            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
1196                quantization_config
1197            )
1198            binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1199                input_qspec_map=binary_node_input_qspec_map,
1200                _annotated=True,
1201                _is_output_of_quantized_pattern=True,
1202            )
1203
1204    def _annotate_conv2d_unary(
1205        self,
1206        gm: torch.fx.GraphModule,
1207        quantization_config: Optional[QuantizationConfig],
1208        filter_fn: Optional[FilterFn] = None,
1209    ) -> None:
1210        fused_partitions = []
1211        unary_patterns = [
1212            [torch.nn.Conv2d, torch.nn.ReLU],
1213            [torch.nn.Conv2d, torch.nn.Hardtanh],
1214            [torch.nn.Conv2d, torch.nn.Hardswish],
1215            [torch.nn.Conv2d, torch.nn.ReLU6],
1216            [torch.nn.Conv2d, torch.nn.SiLU],
1217        ]
1218        for unary_pattern in unary_patterns:
1219            partitions = find_sequential_partitions(gm, unary_pattern)
1220            if partitions:
1221                # Extend the fused_partitions if partitions is not empty
1222                fused_partitions.extend(partitions)
1223
1224        for fused_partition in fused_partitions:
1225            conv_partition, unary_partition = fused_partition
1226            conv_node, unary_node = self._get_output_nodes_of_partitions(
1227                [conv_partition, unary_partition]
1228            )
1229            if (
1230                conv_node.op != "call_function"
1231                or conv_node.target != torch.ops.aten.conv2d.default
1232            ):
1233                continue
1234            if _skip_annotate([unary_node, conv_node], filter_fn):
1235                continue
1236
1237            if quantization_config is None:
1238                _annotate_nodes_not_quantize([conv_node, unary_node])
1239                continue
1240
1241            self._annotate_conv_node_helper(conv_node, False, quantization_config)
1242            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1243                _annotated=True,
1244                _is_output_of_quantized_pattern=True,
1245            )
1246
1247    def _annotate_conv2d(
1248        self,
1249        gm: torch.fx.GraphModule,
1250        quantization_config: Optional[QuantizationConfig],
1251        filter_fn: Optional[FilterFn] = None,
1252    ) -> None:
1253        conv_partitions = get_source_partitions(
1254            gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d]
1255        )
1256        conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values()))
1257        for conv_partition in conv_partitions:
1258            if len(conv_partition.output_nodes) > 1:
1259                raise ValueError("conv partition has more than one output node")
1260            conv_node = conv_partition.output_nodes[0]
1261            if (
1262                conv_node.op != "call_function"
1263                or conv_node.target != torch.ops.aten.conv2d.default
1264            ):
1265                raise ValueError(f"{conv_node} is not an aten conv2d operator")
1266            # skip annotation if it is already annotated
1267            if _skip_annotate([conv_node], filter_fn):
1268                continue
1269            self._annotate_conv_node_helper(conv_node, True, quantization_config)
1270
1271    def _annotate_maxpool2d(
1272        self,
1273        node: Node,
1274        quantization_config: Optional[QuantizationConfig],
1275    ) -> None:
1276        if node.target is not torch.ops.aten.max_pool2d.default:
1277            return
1278        if quantization_config is None:
1279            _annotate_nodes_not_quantize(node)
1280            return
1281
1282        maxpool_node = node
1283        if _is_any_annotated(
1284            [
1285                maxpool_node,
1286            ]
1287        ):
1288            return
1289
1290        input_node = maxpool_node.args[0]
1291        assert isinstance(input_node, Node)
1292        input_qspec_map = {}
1293        input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
1294        maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1295            input_qspec_map=input_qspec_map,
1296            _annotated=True,
1297            _is_output_of_quantized_pattern=True,
1298        )
1299
1300    def _annotate_cat(
1301        self, node: Node, quantization_config: QuantizationConfig
1302    ) -> None:
1303        if quantization_config is None:
1304            _annotate_nodes_not_quantize(node)
1305            return
1306        cat_node = node
1307        input_nodes = cat_node.args[0]
1308        assert isinstance(input_nodes, Sequence)
1309        first_input_node = input_nodes[0]
1310        input_qspec_map = {}
1311        assert isinstance(first_input_node, Node)
1312        assert isinstance(cat_node, Node)
1313        input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config)
1314        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
1315            (first_input_node, cat_node)
1316        )
1317
1318        for input_node in input_nodes[1:]:
1319            if input_node not in input_qspec_map:
1320                # There has the case of cat same nodes: torch.cat([input0, input0], 1)
1321                assert isinstance(input_node, Node)
1322                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
1323
1324        cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1325            input_qspec_map=input_qspec_map,
1326            _annotated=True,
1327            _is_output_of_quantized_pattern=True,
1328        )
1329
1330    def _annotate_propagation_quantizable_pattern_entry(
1331        self,
1332        gm: torch.fx.GraphModule,
1333        quantization_config: Optional[QuantizationConfig],
1334        filter_fn: Optional[FilterFn] = None,
1335    ):
1336        for node in gm.graph.nodes:
1337            self._annotate_propagation_quantizable_pattern(
1338                node, quantization_config, filter_fn
1339            )
1340
1341    def _annotate_propagation_quantizable_pattern(
1342        self, node: Node, quantization_config, filter_fn
1343    ) -> None:
1344        # Propagate annotation to quantizable patterns.
1345        if (
1346            (node.target in propagation_quantizable_ops)
1347            and (not _is_any_annotated([node]))
1348            and (node.op == "call_function")
1349        ):
1350
1351            def is_all_inputs_connected_to_quantized_op(input_nodes):
1352                # Ensure all the inputs connect to fusion pattern or quantized node
1353                for input_node in input_nodes:
1354                    if not _is_quantized_op_pt2e(input_node):
1355                        return False
1356                return True
1357
1358            if _skip_annotate([node], filter_fn):
1359                return
1360
1361            if quantization_config is None:
1362                _annotate_nodes_not_quantize(node)
1363                return
1364
1365            if node.target is torch.ops.aten.max_pool2d.default:
1366                # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not
1367                input_nodes_to_check = [node.all_input_nodes[0]]
1368                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
1369                    if quantization_config is not None:
1370                        warnings.warn(
1371                            f"The input of maxpool2d is not quantized, skip annotate maxpool2d with config {quantization_config}."
1372                        )
1373                    return
1374
1375                self._annotate_maxpool2d(node, quantization_config)
1376                return
1377            elif node.target is torch.ops.aten.cat.default:
1378                input_nodes_to_check = node.all_input_nodes
1379                if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
1380                    return
1381                self._annotate_cat(node, quantization_config)
1382            else:
1383                input_node = node.all_input_nodes[0]
1384                if not is_all_inputs_connected_to_quantized_op(
1385                    [
1386                        input_node,
1387                    ]
1388                ):
1389                    return
1390                input_qspec_map = {}
1391                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
1392                node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1393                    input_qspec_map=input_qspec_map,
1394                    _annotated=True,
1395                    _is_output_of_quantized_pattern=True,
1396                )
1397        return
1398
1399    def _annotate_output_share_observer_as_input(
1400        self, input_node: Node, source_node: Node
1401    ):
1402        source_node_quantization_annotation = (
1403            source_node.meta[QUANT_ANNOTATION_KEY]
1404            if QUANT_ANNOTATION_KEY in source_node.meta
1405            else None
1406        )
1407        if (
1408            source_node_quantization_annotation
1409            and source_node_quantization_annotation._is_output_of_quantized_pattern
1410        ):
1411            edge_or_node = (input_node, source_node)
1412            source_node_quantization_annotation.output_qspec = SharedQuantizationSpec(
1413                edge_or_node
1414            )
1415        return
1416
1417    def _annotate_output_for_int8_in_int8_out_pattern_entry(
1418        self,
1419        model: torch.fx.GraphModule,
1420    ):
1421        for node in model.graph.nodes:
1422            self._annotate_output_for_int8_in_int8_out_pattern(node)
1423
1424    def _annotate_output_for_int8_in_int8_out_pattern(
1425        self,
1426        node: Node,
1427    ) -> None:
1428        r"""
1429        Check and insert observer at output of node in int8_in_int8_out_ops if needed.
1430        Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/
1431        90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495
1432        """
1433        edge_or_node: Tuple[Node, Node]
1434        if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])):
1435            if node.target == torch.ops.aten.max_pool2d.default:
1436                maxpool_node = node
1437                if not _is_all_annotated(
1438                    [
1439                        maxpool_node,
1440                    ]
1441                ):
1442                    return
1443
1444                # Get the quantization_annotation from getitem_node
1445                maxpool_node_quantization_annotation = (
1446                    maxpool_node.meta[QUANT_ANNOTATION_KEY]
1447                    if QUANT_ANNOTATION_KEY in maxpool_node.meta
1448                    else None
1449                )
1450                if (
1451                    maxpool_node_quantization_annotation
1452                    and maxpool_node_quantization_annotation._is_output_of_quantized_pattern
1453                ):
1454                    # Annotate the output_qspec of getitem_node
1455                    input_act = maxpool_node.args[0]
1456                    assert isinstance(input_act, Node)
1457                    assert isinstance(maxpool_node, Node)
1458                    edge_or_node = (input_act, maxpool_node)
1459                    maxpool_node_quantization_annotation.output_qspec = (
1460                        SharedQuantizationSpec(edge_or_node)
1461                    )
1462            else:
1463                input_node = node.all_input_nodes[0]
1464                self._annotate_output_share_observer_as_input(input_node, node)
1465        return
1466
1467    def _annotate_linear(
1468        self,
1469        gm: torch.fx.GraphModule,
1470        quantization_config: Optional[QuantizationConfig],
1471        filter_fn: Optional[FilterFn] = None,
1472    ) -> None:
1473        linear_partitions = get_source_partitions(
1474            gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
1475        )
1476        linear_partitions = list(
1477            itertools.chain.from_iterable(linear_partitions.values())
1478        )
1479        for partition in linear_partitions:
1480            if len(partition.output_nodes) > 1:
1481                raise ValueError(
1482                    "Linear partition cannot have more than one output node"
1483                )
1484            linear_node = partition.output_nodes[0]
1485            if linear_node.op != "call_function" or linear_node.target not in (
1486                torch.ops.aten.linear.default,
1487            ):
1488                raise ValueError(f"{linear_node} is not an aten linear operator")
1489            # skip annotation if it is already annotated
1490            if _skip_annotate([linear_node], filter_fn):
1491                continue
1492            self._annotate_linear_node_helper(linear_node, True, quantization_config)
1493
1494    def _annotate_linear_unary(
1495        self,
1496        gm: torch.fx.GraphModule,
1497        quantization_config: Optional[QuantizationConfig],
1498        filter_fn: Optional[FilterFn] = None,
1499    ) -> None:
1500        postop_list = [
1501            torch.nn.ReLU,
1502            torch.nn.LeakyReLU,
1503            torch.nn.Tanh,
1504            torch.nn.GELU,
1505        ]
1506        fused_partitions: List[tuple] = []
1507        for postop in postop_list:
1508            fused_partitions = fused_partitions + find_sequential_partitions(
1509                gm, [torch.nn.Linear, postop]
1510            )
1511        for fused_partition in fused_partitions:
1512            linear_partition, unary_partition = fused_partition
1513            linear_node, unary_node = self._get_output_nodes_of_partitions(
1514                [linear_partition, unary_partition]
1515            )
1516            if linear_node.op != "call_function" or linear_node.target not in (
1517                torch.ops.aten.linear.default,
1518            ):
1519                continue
1520            if _skip_annotate([unary_node, linear_node], filter_fn):
1521                continue
1522
1523            if quantization_config is None:
1524                _annotate_nodes_not_quantize([linear_node, unary_node])
1525                continue
1526
1527            self._annotate_linear_node_helper(linear_node, False, quantization_config)
1528            unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
1529                _annotated=True,
1530                _is_output_of_quantized_pattern=True,
1531            )
1532
1533    def _annotate_linear_binary_unary(
1534        self,
1535        gm: torch.fx.GraphModule,
1536        quantization_config: Optional[QuantizationConfig],
1537        filter_fn: Optional[FilterFn] = None,
1538    ) -> None:
1539        # linear + binary_op + (optional) unary op
1540        binary_op_list = [operator.add]
1541        unary_op_list = [torch.nn.ReLU, None]
1542        combinations = itertools.product(binary_op_list, unary_op_list)
1543        for binary_op, unary_op in combinations:
1544            has_unary = unary_op is not None
1545            seq_partition = [torch.nn.Linear, binary_op]
1546            if has_unary:
1547                seq_partition.append(unary_op)
1548            fused_partitions = find_sequential_partitions(gm, seq_partition)
1549            for fused_partition in fused_partitions:
1550                unary_partition, unary_node = None, None
1551                if has_unary:
1552                    (
1553                        linear_partition,
1554                        binary_partition,
1555                        unary_partition,
1556                    ) = fused_partition
1557                    (
1558                        linear_node,
1559                        binary_node,
1560                        unary_node,
1561                    ) = self._get_output_nodes_of_partitions(
1562                        [linear_partition, binary_partition, unary_partition]
1563                    )
1564                else:
1565                    linear_partition, binary_partition = fused_partition
1566                    linear_node, binary_node = self._get_output_nodes_of_partitions(
1567                        [linear_partition, binary_partition]
1568                    )
1569                if len(linear_node.users) != 1:
1570                    # Linear Node should only has 1 user node
1571                    continue
1572                (
1573                    linear_node_idx,
1574                    extra_input_node_idx,
1575                ) = self._get_input_idx_for_binary_node(linear_node, binary_node)
1576                if (linear_node_idx is None) or (extra_input_node_idx is None):
1577                    continue
1578                if linear_node != binary_node.args[linear_node_idx]:
1579                    raise ValueError(
1580                        f"{linear_node} doesn't match input of binary node"
1581                    )
1582                assert isinstance(linear_node, Node)
1583                if (
1584                    linear_node.op != "call_function"
1585                    or linear_node.target != torch.ops.aten.linear.default
1586                ):
1587                    # No linear node found to be fused with add
1588                    continue
1589                node_list = (
1590                    [binary_node, linear_node]
1591                    if unary_node is None
1592                    else [unary_node, binary_node, linear_node]
1593                )
1594                if _skip_annotate(node_list, filter_fn):
1595                    continue
1596
1597                if quantization_config is None:
1598                    _annotate_nodes_not_quantize(node_list)
1599                    continue
1600
1601                self._annotate_linear_node_helper(
1602                    linear_node, False, quantization_config
1603                )
1604                # We don't insert q-dq before the binary input node due to accuracy issues
1605                binary_node.meta[
1606                    QUANT_ANNOTATION_KEY
1607                ] = _X86InductorQuantizationAnnotation(
1608                    input_qspec_map={},
1609                    _annotated=True,
1610                    _is_output_of_quantized_pattern=(not has_unary),
1611                )
1612                if unary_node is not None:
1613                    unary_node.meta[
1614                        QUANT_ANNOTATION_KEY
1615                    ] = _X86InductorQuantizationAnnotation(
1616                        _annotated=True,
1617                        _is_output_of_quantized_pattern=True,
1618                    )
1619
1620    def validate(self, model: torch.fx.GraphModule) -> None:
1621        pass
1622
1623    @classmethod
1624    def get_supported_operators(cls) -> List[OperatorConfig]:
1625        return cls.supported_config_and_operators
1626