1# mypy: allow-untyped-defs
2import operator
3from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
4
5import torch
6import torch.ao.nn.intrinsic as nni
7import torch.ao.nn.intrinsic.quantized as nniq
8import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
9import torch.ao.nn.quantized as nnq
10import torch.ao.nn.quantized.dynamic as nnqd
11import torch.ao.nn.quantized.reference as nnqr
12import torch.nn as nn
13import torch.nn.functional as F
14from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
15from torch.ao.quantization.qconfig import QConfigAny
16from torch.ao.quantization.quantization_mappings import get_quantized_operator
17from torch.ao.quantization.utils import _parent_name
18from torch.fx import GraphModule, map_arg, Node
19from torch.fx.graph import Graph
20
21from .utils import (
22    collect_producer_nodes,
23    create_node_from_old_node_preserve_meta,
24    get_linear_prepack_op_for_dtype,
25    get_new_attr_name_with_prefix,
26    get_qconv_prepack_op,
27    graph_module_from_producer_nodes,
28)
29
30
31QOP_TO_ARG_NAMES_TO_SKIP = {
32    torch._ops.ops.quantized.hardswish: ["inplace"],
33    torch._ops.ops.quantized.elu: ["inplace"],
34    torch._ops.ops.quantized.dropout: ["inplace"],
35    torch._ops.ops.quantized.instance_norm: [
36        "running_mean",
37        "running_var",
38        "use_input_stats",
39        "momentum",
40    ],
41}
42
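# A minimal sketch (hypothetical helper, not used elsewhere in this file) of how
# QOP_TO_ARG_NAMES_TO_SKIP is consumed: kwargs that only apply to the float op
# (e.g. `inplace`) are dropped before the call is rewritten to the quantized op.
def _example_drop_skipped_kwargs(qop, kwargs):
    kwargs = dict(kwargs)
    for arg_name in QOP_TO_ARG_NAMES_TO_SKIP.get(qop, []):
        kwargs.pop(arg_name, None)
    return kwargs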
43
44def _is_node_in_list(node, modules, func_list, method_list, module_type_list):
45    is_call_function = node.op == "call_function" and node.target in func_list
46    is_call_method = node.op == "call_method" and node.target in method_list
47    is_call_module = (
48        node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
49    )
50    return is_call_function, is_call_method, is_call_module
51
52
53def is_fixed_qparams_node(node, modules):
54    func_list = [
55        torch.nn.functional.hardsigmoid,
56        torch.nn.functional.sigmoid,
57        torch.sigmoid,
58        torch.tanh,
59    ]
60    method_list = [
61        "hardsigmoid",
62        "hardsigmoid_",
63        "sigmoid",
64        "sigmoid_",
65        "tanh",
66        "tanh_",
67    ]
68    module_type_list = [
69        torch.nn.Hardsigmoid,
70        torch.nn.Sigmoid,
71        torch.nn.Tanh,
72        torch.nn.Softmax,
73    ]
74    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
75
76
77def is_default_node(node, modules):
78    func_list = [
79        torch.nn.functional.elu,
80        torch.nn.functional.hardswish,
81        torch.nn.functional.instance_norm,
82        torch.nn.functional.layer_norm,
83        torch.nn.functional.leaky_relu,
84        torch.nn.functional.dropout,
85    ]
86    method_list: List[Any] = []
87    module_type_list = [
88        nnqr.ConvTranspose1d,
89        nnqr.ConvTranspose2d,
90        nnqr.ConvTranspose3d,
91        torch.nn.ELU,
92        torch.nn.LeakyReLU,
93        torch.nn.Hardswish,
94        torch.nn.InstanceNorm1d,
95        torch.nn.InstanceNorm2d,
96        torch.nn.InstanceNorm3d,
97        torch.nn.LayerNorm,
98        torch.nn.Dropout,
99        torch.nn.PReLU,
100        torch.nn.BatchNorm2d,
101        torch.nn.BatchNorm3d,
102        torch.ao.nn.intrinsic.BNReLU2d,
103        torch.ao.nn.intrinsic.BNReLU3d,
104    ]
105    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
106
107
108def is_copy_node(node, modules):
109    func_list = [
110        torch.adaptive_avg_pool1d,
111        torch.nn.functional.adaptive_avg_pool2d,
112        torch.nn.functional.adaptive_avg_pool3d,
113        torch.nn.functional.hardtanh,
114        torch.nn.functional.hardtanh_,
115        torch.nn.functional.interpolate,
116        torch.nn.functional.max_pool1d,
117        torch.nn.functional.max_pool2d,
118        torch.nn.functional.max_pool3d,
119        torch.nn.functional.relu,
120        torch.nn.functional.relu6,
121        torch.avg_pool1d,
122        torch._C._nn.avg_pool2d,
123        torch._C._nn.avg_pool3d,
124        torch.clamp,
125        torch.flatten,
126        torch.mean,
127        operator.floordiv,
128        # F.channel_shuffle and torch.channel_shuffle are essentially the same thing
129        # so we only need to put one of them here
130        torch.channel_shuffle,
131    ]
132    method_list = [
133        "clamp",
134        "mean",
135        "relu",
136        "relu_",
137    ]
138    module_type_list = [
139        torch.nn.AdaptiveAvgPool1d,
140        torch.nn.AdaptiveAvgPool2d,
141        torch.nn.AdaptiveAvgPool3d,
142        torch.nn.AvgPool1d,
143        torch.nn.AvgPool2d,
144        torch.nn.AvgPool3d,
145        torch.nn.Hardtanh,
146        torch.nn.MaxPool1d,
147        torch.nn.MaxPool2d,
148        torch.nn.MaxPool3d,
149        torch.nn.ReLU,
150        torch.nn.ReLU6,
151        torch.nn.ChannelShuffle,
152    ]
153    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
154
155
156def is_general_tensor_shape_node(node, modules):
157    func_list = [
158        torch.narrow,
159        torch.transpose,
160        torch.repeat_interleave,
161        torch.squeeze,
162        torch.stack,
163        torch.unsqueeze,
164        torch.nn.functional.pixel_shuffle,
165        torch.nn.functional.pixel_unshuffle,
166    ]
167    method_list = [
168        "contiguous",
169        "detach",
170        "detach_",
171        "permute",
172        "repeat",
173        "repeat_interleave",
174        "reshape",
175        "resize_",
176        "shape",
177        "size",
178        "squeeze",
179        "squeeze_",
180        "transpose",
181        "unsqueeze",
182        "unsqueeze_",
183        "view",
184    ]
185    module_type_list = [
186        torch.nn.Identity,
187        torch.nn.PixelShuffle,
188        torch.nn.PixelUnshuffle,
189    ]
190    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
191
192
193def is_other_node(node, modules):
194    func_list = [
195        torch.cat,
196    ]
197    method_list: List[Any] = []
198    module_type_list: List[Any] = []
199    return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
200
201
202def is_special_pattern_node(node, modules):
203    res_function, res_method, res_module = False, False, False
204    for checker in [
205        is_fixed_qparams_node,
206        is_default_node,
207        is_copy_node,
208        is_general_tensor_shape_node,
209        is_other_node,
210    ]:
211        is_call_function, is_call_method, is_call_module = checker(node, modules)
212        res_function = res_function or is_call_function
213        res_method = res_method or is_call_method
214        res_module = res_module or is_call_module
215    return res_function, res_method, res_module
216
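# A minimal sketch (hypothetical helper) of how the checkers above are used:
# each returns three booleans so callers can tell whether the match came from a
# call_function, a call_method, or a call_module node.
def _example_node_matches_copy_pattern(node: Node, modules: Dict[str, nn.Module]) -> bool:
    is_func, is_meth, is_mod = is_copy_node(node, modules)
    return is_func or is_meth or is_mod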
217
218def is_dequantize_node(node):
219    return (
220        isinstance(node, Node)
221        and node.op == "call_method"
222        and node.target == "dequantize"
223    )
224
225
226def is_getattr_tensor_metadata_node(node):
227    return (
228        node.op == "call_function"
229        and node.target == getattr
230        and node.args[1] in ["shape"]
231    )
232
233
234def is_get_tensor_info_node(node):
235    return node.op == "call_method" and node.target in ["shape", "size"]
236
237
238def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigAny]):
239    """
240    Return True if the op is configured with a None qconfig, False otherwise.
241    Note: we may need to generalize this to also check the dtype and only lower
242    when the dtype matches, but right now fbgemm/qnnpack only support a single
243    dtype, so it is OK for now.
244    """
245    return op.name in qconfig_map and qconfig_map[op.name] is None
246
247
248# Mapping from reference module class to the replacement static quantized module class for lowering
249STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[WeightedQuantizedModule]] = {
250    nnqr.Linear: nnq.Linear,
251    nnqr.Conv1d: nnq.Conv1d,
252    nnqr.Conv2d: nnq.Conv2d,
253    nnqr.Conv3d: nnq.Conv3d,
254}
255
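# A minimal sketch (hypothetical helper, not used by the passes below) of how
# STATIC_LOWER_MODULE_MAP is consumed: the reference module's class selects a
# quantized replacement class, and `from_reference` builds the quantized module
# from it plus the output scale/zero_point taken from the following quantize node.
def _example_static_module_swap(
    ref_module: nn.Module, output_scale: float, output_zero_point: int
) -> WeightedQuantizedModule:
    q_class = STATIC_LOWER_MODULE_MAP[type(ref_module)]
    return q_class.from_reference(ref_module, output_scale, output_zero_point)
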
256# Mapping from reference module class to the replacement dynamic quantized module class for lowering
257DYNAMIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
258    nnqr.Linear: nnqd.Linear,
259    nnqr.GRUCell: nnqd.GRUCell,
260    nnqr.LSTMCell: nnqd.LSTMCell,
261    nnqr.RNNCell: nnqd.RNNCell,
262    nnqr.LSTM: nnqd.LSTM,
263    nnqr.GRU: nnqd.GRU,
264}
265
266# Mapping from reference module class to the replacement weight only quantized module class for lowering
267# TODO: correct the namespace for these modules
268WEIGHT_ONLY_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
269    nnqr.Embedding: nnq.Embedding,
270    nnqr.EmbeddingBag: nnq.EmbeddingBag,
271}
272
273# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge
274# _lower_static_weighted_ref_module and special_pattern_replacement
275SPECIAL_PATTERN_LOWER_MODULE_MAP = {
276    nn.BatchNorm2d: nnq.BatchNorm2d,
277    nn.BatchNorm3d: nnq.BatchNorm3d,
278    nnqr.ConvTranspose1d: nnq.ConvTranspose1d,
279    nnqr.ConvTranspose2d: nnq.ConvTranspose2d,
280    nnqr.ConvTranspose3d: nnq.ConvTranspose3d,
281    nn.ELU: nnq.ELU,
282    nn.LeakyReLU: nnq.LeakyReLU,
283    nn.Hardswish: nnq.Hardswish,
284    nn.InstanceNorm1d: nnq.InstanceNorm1d,
285    nn.InstanceNorm2d: nnq.InstanceNorm2d,
286    nn.InstanceNorm3d: nnq.InstanceNorm3d,
287    nn.LayerNorm: nnq.LayerNorm,
288    nn.Dropout: nnq.Dropout,
289    nn.Softmax: nnq.Softmax,
290    nn.PReLU: nnq.PReLU,
291    nni.BNReLU2d: nniq.BNReLU2d,
292    nni.BNReLU3d: nniq.BNReLU3d,
293}
294
295# Mapping from fused module class to a 2-tuple of:
296#   1) The inner reference module class
297#   2) The replacement static quantized module class for lowering
298STATIC_LOWER_FUSED_MODULE_MAP: Dict[
299    Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]
300] = {
301    nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
302    # TODO: LinearLeakyReLU is registered globally, but it is only fused and
303    # lowered when onednn's backend config is used. Maybe we need to separate
304    # registration and lowering functions for different backends in the future.
305    nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU),
306    nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh),
307    nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
308    nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
309    nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
310}
311
312# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP:
313# the reference node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs.
314# Mapping from fused module class to a 2-tuple of:
315#   1) The inner reference module class
316#   2) The replacement static quantized module class for lowering
317STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: Dict[
318    Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]
319] = {
320    nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d),
321    nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d),
322}
323
324# Mapping from fused module class to a 2-tuple of:
325#   1) The inner reference module class
326#   2) The replacement dynamic quantized module class for lowering
327DYNAMIC_LOWER_FUSED_MODULE_MAP: Dict[
328    Type[nn.Module], Tuple[Type[nn.Module], Type[nn.Module]]
329] = {
330    nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU),
331}
332
333# Mapping from a functional to lower to a 2-tuple of
334#   1) The quantized version of the op
335#   2) The quantized version of the op fused with relu, if it exists, else None
336STATIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Tuple[Callable, Optional[Callable]]] = {
337    F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu),
338    F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu),
339    F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu),
340    F.conv3d: (torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu),
341    F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None),
342    F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None),
343    F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None),
344}
345
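# A minimal sketch (hypothetical helper) of how STATIC_LOWER_FUNCTIONAL_MAP is
# consumed: the plain quantized op is used unless the reference pattern also
# ends in a relu and a fused variant exists.
def _example_select_static_quantized_op(ref_op: Callable, has_relu: bool) -> Callable:
    q_op, q_relu_op = STATIC_LOWER_FUNCTIONAL_MAP[ref_op]
    if has_relu and q_relu_op is not None:
        return q_relu_op
    return q_op
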
346WEIGHT_PREPACK_OPS: Set[Callable] = {
347    torch._ops.ops.quantized.linear_prepack,
348    torch._ops.ops.quantized.linear_prepack_fp16,
349    torch._ops.ops.quantized.conv1d_prepack,
350    torch._ops.ops.quantized.conv2d_prepack,
351    torch._ops.ops.quantized.conv3d_prepack,
352    torch.ops.quantized.conv_transpose1d_prepack,
353    torch.ops.quantized.conv_transpose2d_prepack,
354    torch.ops.quantized.conv_transpose3d_prepack,
355}
356
357# Mapping from a functional to a dictionary, where the key is a 2-tuple of
358# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of
359#   1) The dynamically quantized version of the op
360#   2) The dynamically quantized version of the op fused with relu, if it exists, else None
361DYNAMIC_LOWER_FUNCTIONAL_MAP: Dict[
362    Callable, Dict[Tuple[torch.dtype, torch.dtype], Tuple[Callable, Optional[Callable]]]
363] = {
364    F.linear: {
365        (torch.quint8, torch.qint8): (
366            torch.ops.quantized.linear_dynamic,
367            torch.ops.quantized.linear_relu_dynamic,
368        ),
369        (torch.float16, torch.float16): (
370            torch.ops.quantized.linear_dynamic_fp16,
371            torch.ops.quantized.linear_relu_dynamic_fp16,
372        ),
373    },
374    # dynamic conv + relu is not available yet
375    F.conv1d: {
376        (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None),
377    },
378    F.conv2d: {
379        (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None),
380    },
381    F.conv3d: {
382        (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None),
383    },
384}
385
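# A minimal sketch (hypothetical helper) of how DYNAMIC_LOWER_FUNCTIONAL_MAP is
# keyed: the (activation_dtype, weight_dtype) pair observed in the reference
# pattern selects the dynamic op and its relu-fused variant, if any.
def _example_select_dynamic_quantized_op(
    ref_op: Callable, activation_dtype: torch.dtype, weight_dtype: torch.dtype
) -> Tuple[Callable, Optional[Callable]]:
    return DYNAMIC_LOWER_FUNCTIONAL_MAP[ref_op][(activation_dtype, weight_dtype)]
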
386CONV_FUNCTIONAL_OPS: Set[Callable] = {
387    F.conv1d,
388    F.conv2d,
389    F.conv3d,
390}
391
392CONV_TRANSPOSE_FUNCTIONAL_OPS: Set[Callable] = {
393    F.conv_transpose1d,
394    F.conv_transpose2d,
395    F.conv_transpose3d,
396}
397
398# TODO: add tests for lowering these ops
399QBIN_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
400    operator.add: torch.ops.quantized.add,
401    torch.add: torch.ops.quantized.add,
402    operator.mul: torch.ops.quantized.mul,
403    operator.matmul: torch.ops.quantized.matmul,
404    torch.mul: torch.ops.quantized.mul,
405    torch.matmul: torch.ops.quantized.matmul,
406}
407QBIN_RELU_OP_MAPPING: Dict[Union[Callable, str], Callable] = {
408    operator.add: torch.ops.quantized.add_relu,
409    torch.add: torch.ops.quantized.add_relu,
410    operator.mul: torch.ops.quantized.mul_relu,
411    torch.mul: torch.ops.quantized.mul_relu,
412}
413
414
415def _save_packed_weight(self, destination, prefix, keep_vars):
416    for attr_name in dir(self):
417        if "_packed_weight" in attr_name and isinstance(
418            getattr(self, attr_name), torch._C.ScriptObject
419        ):  # type: ignore[attr-defined]
420            packed_weight = getattr(self, attr_name)
421            destination[prefix + attr_name] = packed_weight
422
423
424def _load_packed_weight(
425    self,
426    state_dict,
427    prefix,
428    local_metadata,
429    strict,
430    missing_keys,
431    unexpected_keys,
432    error_msgs,
433):
434    attrs_to_pop = []
435    for attr_name in state_dict:
436        if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
437            setattr(self, attr_name, state_dict[attr_name])
438            attrs_to_pop.append(attr_name)
439
440    # pop the packed param attributes
441    for attr_name in attrs_to_pop:
442        state_dict.pop(attr_name)
443
444
445def fold_weight(
446    quantized_model: GraphModule, node_name_to_scope: Dict[str, Tuple[str, type]]
447) -> GraphModule:
448    """
449    Trace back from the weight node until we hit getattr, reconstruct the
450    graph module with the traced nodes and run the graph module to pack the
451    weight, then replace the original chain of ops with the packed weight.
452    """
453    packed_weights = {}
454    # map from folded node name to the prepacked weight name
455    folded_nodes = {}
456    # get packed weights
457    for node in quantized_model.graph.nodes:
458        if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS:
459            nodes_to_fold = collect_producer_nodes(node)
460            if nodes_to_fold is not None:
461                for node_to_fold in nodes_to_fold:
462                    folded_nodes[node_to_fold.name] = node
463
464                prepacking_module = graph_module_from_producer_nodes(
465                    quantized_model, nodes_to_fold
466                )
467                packed_weight = prepacking_module()
468                packed_weights[node.name] = packed_weight
469
470    # remove folded nodes and replace the prepacking node with getattr
471    folded_graph = Graph()
472    env: Dict[Any, Any] = {}
473
474    def load_arg(a):
475        return map_arg(a, lambda node: env[node.name])
476
477    for node in quantized_model.graph.nodes:
478        prepack_node = folded_nodes.get(node.name, None)
479        if prepack_node is node:
480            packed_weight = packed_weights[node.name]
481            # add a prepacked attribute to root
482            op_node = next(iter(prepack_node.users))
483            module_path, _ = node_name_to_scope[op_node.name]
484            get_new_packed_weight_name = get_new_attr_name_with_prefix(
485                module_path + "_packed_weight_"
486            )
487            packed_weight_name = get_new_packed_weight_name(quantized_model)
488            setattr(quantized_model, packed_weight_name, packed_weight)
489            # replace prepack node with a getattr node
490            env[node.name] = folded_graph.create_node(
491                "get_attr", packed_weight_name, (), {}
492            )
493        elif prepack_node is not None:
494            # remove the folded node
495            continue
496        else:
497            # copy other nodes
498            env[node.name] = folded_graph.node_copy(node, load_arg)
499
500    quantized_model = GraphModule(quantized_model, folded_graph)
501    quantized_model._register_state_dict_hook(_save_packed_weight)
502    quantized_model.register_load_state_dict_pre_hook(_load_packed_weight)
503    return quantized_model
504
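# A minimal, self-contained sketch (hypothetical, not used by this file) of what
# fold_weight achieves: the quantize + prepack chain that feeds a quantized op
# is evaluated once ahead of time, and the lowered graph only keeps a get_attr
# on the resulting packed weight instead of re-running the chain every forward.
def _example_fold_weight_effect() -> None:
    float_weight = torch.randn(4, 4)
    q_weight = torch.quantize_per_tensor(float_weight, 0.1, 0, torch.qint8)
    # This call is what fold_weight executes offline and stores as an attribute.
    packed_weight = torch.ops.quantized.linear_prepack(q_weight, None)
    _ = packed_weight  # at runtime the lowered graph just reads this attribute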
505
506def _get_module(node: Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
507    """
508    Return the `torch.nn.Module` that corresponds to the specified node's target.
509    If no such module exists, return None.
510    """
511    if node.op == "call_module" and str(node.target) in modules:
512        return modules[str(node.target)]
513    else:
514        return None
515
516
517def _match_static_pattern(
518    node: Node,
519    modules: Dict[str, nn.Module],
520    qconfig_map: Dict[str, QConfigAny],
521    matching_modules_or_ops: List[Callable],
522    dequantize_node_arg_indices: List[int],
523) -> Union[Tuple[Node, Node, Node], Tuple[None, None, None]]:
524    """
525    Match the pattern (dequantize - ref node - quantize) against the node provided.
526
527    If there is a match, return a 3-tuple of:
528      1) q_node: the quantize node,
529      2) relu_node: a relu node wrapping the ref_node, and
530      3) ref_node: a reference module or functional node to replace with its quantized counterpart
531    Otherwise, if there is no match, return a 3-tuple of (None, None, None).
532
533    Parameters:
534      node: The `torch.fx.Node` to match against.
535      modules: A mapping from node names to modules in the model graph, used for module lookup.
536      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
537          If the corresponding qconfig for the reference node is None, then return no match.
538      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
539          If the reference node is not in this list, then return no match.
540      dequantize_node_arg_indices: A list of indices in the reference node args where dequantize
541          nodes may be present. An empty list means skipping the check for dequantize nodes.
542    """
543    SKIP_LOWERING_VALUE = (None, None, None)
544
545    # Match quantize node
546    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
547        return SKIP_LOWERING_VALUE
548    q_node = node
549    ref_node = q_node.args[0]
550    assert isinstance(ref_node, Node)
551
552    # Handle cases where the node is wrapped in a ReLU
553    if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
554        ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU
555    ):
556        relu_node = ref_node
557        ref_node = relu_node.args[0]
558        assert isinstance(ref_node, Node)
559    else:
560        relu_node = None
561    if should_skip_lowering(ref_node, qconfig_map):
562        return SKIP_LOWERING_VALUE
563
564    # Match reference module or functional
565    if isinstance(matching_modules_or_ops[0], type) and issubclass(
566        matching_modules_or_ops[0], nn.Module
567    ):
568        expected_op = "call_module"
569        match_key = type(_get_module(ref_node, modules))
570    else:
571        expected_op = "call_function"
572        match_key = ref_node.target  # type: ignore[assignment]
573    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
574        return SKIP_LOWERING_VALUE
575
576    # Match dequantize node(s). Both of the following conditions must pass:
577    # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node
578    # (2) There must be at least one dequantize node
579    matched_dequantize = False
580    for i in dequantize_node_arg_indices:
581        assert i < len(
582            ref_node.args
583        ), f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
584        arg = ref_node.args[i]
585        if is_dequantize_node(arg):
586            matched_dequantize = True
587        elif isinstance(arg, Node):
588            return SKIP_LOWERING_VALUE
589    if not matched_dequantize:
590        return SKIP_LOWERING_VALUE
591
592    return (q_node, relu_node, ref_node)  # type: ignore[return-value]
593
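# A minimal sketch (hypothetical arguments) of how _match_static_pattern is
# invoked by the lowering passes below: `n` is a candidate quantize_per_tensor
# node, and dequantize_node_arg_indices=[0] asks for a dequantize node feeding
# the reference node's first argument.
def _example_matches_static_linear(
    n: Node, modules: Dict[str, nn.Module], qconfig_map: Dict[str, QConfigAny]
) -> bool:
    q_node, _relu_node, _ref_node = _match_static_pattern(
        n, modules, qconfig_map, [nnqr.Linear], dequantize_node_arg_indices=[0]  # type: ignore[arg-type]
    )
    return q_node is not None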
594
595def _match_static_pattern_with_two_inputs(
596    node: Node,
597    modules: Dict[str, nn.Module],
598    qconfig_map: Dict[str, QConfigAny],
599    matching_modules_or_ops: List[Callable],
600) -> Union[Tuple[Node, Node], Tuple[None, None]]:
601    """
602                      (dequantize \
603    Match the pattern (dequantize - ref node - quantize) against the node provided.
604
605    If there is a match, return a 2-tuple of:
606      1) q_node: the quantize node,
607      2) ref_node: a reference module or functional node to replace with its quantized counterpart
608    Otherwise, if there is no match, return a 2-tuple of (None, None).
609
610    Parameters:
611      node: The `torch.fx.Node` to match against.
612      modules: A mapping from node names to modules in the model graph, used for module lookup.
613      qconfig_map: A mapping from node names to the qconfigs associated with the nodes.
614          If the corresponding qconfig for the reference node is None, then return no match.
615      matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s.
616          If the reference node is not in this list, then return no match.
617    """
618    SKIP_LOWERING_VALUE = (None, None)
619
620    # Match quantize node
621    if node.op != "call_function" or node.target != torch.quantize_per_tensor:
622        return SKIP_LOWERING_VALUE
623    q_node = node
624    ref_node = q_node.args[0]
625    assert isinstance(ref_node, Node)
626
627    if should_skip_lowering(ref_node, qconfig_map):
628        return SKIP_LOWERING_VALUE
629
630    # Match reference module or functional
631    if isinstance(matching_modules_or_ops[0], type) and issubclass(
632        matching_modules_or_ops[0], nn.Module
633    ):
634        expected_op = "call_module"
635        match_key = type(_get_module(ref_node, modules))
636    else:
637        # This pass only supports ops of type "call_module"
638        return SKIP_LOWERING_VALUE
639
640    if ref_node.op != expected_op or match_key not in matching_modules_or_ops:
641        return SKIP_LOWERING_VALUE
642
643    # Check that ref_node has 2 input nodes, both of which are dequantize nodes.
644    if len(ref_node.args) != 2:
645        return SKIP_LOWERING_VALUE
646    for i in range(len(ref_node.args)):
647        arg = ref_node.args[i]
648        if not is_dequantize_node(arg):
649            return SKIP_LOWERING_VALUE
650
651    return (q_node, ref_node)
652
653
654def _lower_static_weighted_ref_module(
655    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
656):
657    """
658    Traverse the graph and find dequantize - ref module - quantize patterns
659    and replace them with the quantized version of the ref module.
660    """
661    modules = dict(model.named_modules(remove_duplicate=False))
662    nodes = list(model.graph.nodes)
663    for n in model.graph.nodes:
664        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
665        matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list(
666            STATIC_LOWER_FUSED_MODULE_MAP.keys()
667        )
668        (q_node, relu_node, ref_node) = _match_static_pattern(
669            n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]  # type: ignore[arg-type]
670        )
671        if q_node is None:
672            continue
673        assert ref_node is not None
674        (_, scale_node, zero_point_node, _) = q_node.args
675        ref_module = _get_module(ref_node, modules)
676        ref_class = type(ref_module)
677        assert isinstance(scale_node, Node)
678        assert isinstance(zero_point_node, Node)
679        assert issubclass(ref_class, nn.Module)
680
681        # Step 1: Change this pattern to use the corresponding quantized module
682        # For fused modules, we also check whether the inner module is a reference module
683        # If so, we replace the entire fused module with the corresponding quantized module
684        if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
685            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
686            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
687                continue
688        else:
689            q_class = STATIC_LOWER_MODULE_MAP[ref_class]
690        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
691        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
692        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
693        # replace reference module with quantized module
694        parent_name, module_name = _parent_name(ref_node.target)
695        setattr(modules[parent_name], module_name, q_module)
696
697        # Step 2: Reroute around dq_node, and remove q_node and its args
698        assert len(ref_node.args) == 1
699        dq_node = ref_node.args[0]
700        assert isinstance(dq_node, Node)
701        ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
702        q_node.replace_all_uses_with(ref_node)
703        model.graph.erase_node(q_node)
704        model.graph.erase_node(scale_node)
705        model.graph.erase_node(zero_point_node)
706
707
708def _lower_static_weighted_ref_module_with_two_inputs(
709    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
710):
711    """
712    Traverse the graph and find patterns
713    dequantize   dequantize
714       \\         //
715        ref module
716            \\
717          quantize
718    and replace them with the quantized version of the ref module.
719    """
720    modules = dict(model.named_modules(remove_duplicate=False))
721    nodes = list(model.graph.nodes)
722    for n in model.graph.nodes:
723        #                                            (dequantize \
724        # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize)
725        matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys())
726        (q_node, ref_node) = _match_static_pattern_with_two_inputs(
727            n, modules, qconfig_map, matching_modules  # type: ignore[arg-type]
728        )
729        if q_node is None:
730            continue
731        assert ref_node is not None
732        (_, scale_node, zero_point_node, _) = q_node.args
733        ref_module = _get_module(ref_node, modules)
734        ref_class = type(ref_module)
735        assert isinstance(scale_node, Node)
736        assert isinstance(zero_point_node, Node)
737        assert issubclass(ref_class, nn.Module)
738
739        # Step 1: Change this pattern to use the corresponding quantized module
740        # For fused modules, we also check whether the inner module is a reference module
741        # If so, we replace the entire fused module with the corresponding quantized module
742        if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP:
743            inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[
744                ref_class
745            ]
746            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
747                continue
748        else:
749            continue
750        output_scale = getattr(model, scale_node.target)  # type: ignore[arg-type]
751        output_zero_point = getattr(model, zero_point_node.target)  # type: ignore[arg-type]
752        q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
753        # replace reference module with quantized module
754        parent_name, module_name = _parent_name(ref_node.target)
755        setattr(modules[parent_name], module_name, q_module)
756
757        # Step 2: Reroute around dq_node, and remove q_node and its args
758        assert len(ref_node.args) == 2
759        for arg in ref_node.args:
760            if not is_dequantize_node(arg):
761                continue
762            dq_node = arg
763            assert isinstance(dq_node, Node)
764            ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
765
766        q_node.replace_all_uses_with(ref_node)
767        model.graph.erase_node(q_node)
768        model.graph.erase_node(scale_node)
769        model.graph.erase_node(zero_point_node)
770
771
772def _lower_dynamic_weighted_ref_module(model: GraphModule):
773    """
774    Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns
775    and replace them with the dynamically quantized version of the ref module.
776    """
777    named_modules = dict(model.named_modules(remove_duplicate=False))
778    for n in model.graph.nodes:
779        if n.op != "call_module" or type(named_modules[str(n.target)]) not in set(
780            DYNAMIC_LOWER_MODULE_MAP.keys()
781        ).union(set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())):
782            continue
783        ref_node = n
784        dq_node = ref_node.args[0]
785        if dq_node.op != "call_method" or dq_node.target != "dequantize":
786            continue
787
788        input_dynamic_q_node = dq_node.args[0]
789
790        if (
791            input_dynamic_q_node.op != "call_function"
792            or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic
793        ):
794            continue
795
796        activation_dtype = input_dynamic_q_node.args[1]
797        is_fp16 = activation_dtype == torch.float16
798        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
799        if not is_int8 and not is_fp16:
800            continue
801
802        ref_module = named_modules[str(ref_node.target)]
803        ref_class = type(ref_module)
804        if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
805            inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
806            if type(ref_module[0]) != inner_ref_class:
807                continue
808        else:
809            q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class)  # type: ignore[assignment]
810        # TODO: maybe define a WeightedDynamicallyQuantizedModule
811        q_module = q_class.from_reference(ref_module)  # type: ignore[attr-defined]
812
813        # replace reference module with dynamically quantized module
814        parent_name, module_name = _parent_name(ref_node.target)
815        setattr(named_modules[parent_name], module_name, q_module)
816        ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0])
817
818
819def _lower_weight_only_weighted_ref_module(model: GraphModule):
820    """
821    Traverse the graph and find ref_module patterns
822    and replace them with the weight only quantized version of the ref module.
823    """
824    named_modules = dict(model.named_modules(remove_duplicate=False))
825    for n in model.graph.nodes:
826        if n.op != "call_module" or type(named_modules[str(n.target)]) not in set(
827            WEIGHT_ONLY_LOWER_MODULE_MAP.keys()
828        ):
829            continue
830        ref_node = n
831        ref_module = named_modules[str(ref_node.target)]
832        ref_class = type(ref_module)
833        q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class)
834        # TODO: WeightedQuantizedModule is currently assuming static quant apis
835        # with output_scale, output_zero_point in from_reference, we may want to
836        # relax that, or rename this
837        # TODO: maybe define a WeightedWeightOnlyQuantizedModule
838        q_module = q_class.from_reference(ref_module)  # type: ignore[union-attr]
839
840        # replace reference module with the weight only quantized module
841        parent_name, module_name = _parent_name(ref_node.target)
842        setattr(named_modules[parent_name], module_name, q_module)
843
844
845def _lower_static_weighted_ref_functional(
846    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
847):
848    """
849    Traverse the graph and replace functional reference patterns with their quantized versions.
850    """
851    modules = dict(model.named_modules(remove_duplicate=False))
852    nodes = list(model.graph.nodes)
853    for n in model.graph.nodes:
854        # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize)
855        matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys())
856        (q_node, relu_node, func_node) = _match_static_pattern(
857            n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1]
858        )
859        if q_node is None:
860            continue
861        assert func_node is not None
862        (_, output_scale_node, output_zp_node, _) = q_node.args
863        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
864        assert isinstance(output_zp_node, Node)
865        assert isinstance(input_dq_node, Node)
866        assert isinstance(weight_dq_node, Node)
867        quantized_weight = weight_dq_node.args[0]
868        assert isinstance(quantized_weight, Node)
869        if quantized_weight.op != "call_function" or quantized_weight.target not in (
870            torch.quantize_per_tensor,
871            torch.quantize_per_channel,
872        ):
873            continue
874
875        # Step 1: Replace quantized weights with packed weights, which will be folded later
876        # Use the right prepack op and prepare the corresponding args
877        # Linear prepack args: (quantized weights[, bias])
878        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
879        prepack_args = [quantized_weight] + remaining_func_args
880        if func_node.target == F.linear:
881            weight_dtype = quantized_weight.args[-1]
882            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
883        elif func_node.target in CONV_FUNCTIONAL_OPS:
884            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
885            # For conv1d, the stride, padding, and dilation args may be ints,
886            # in which case we need to convert them to tuples
887            if func_node.target == F.conv1d:
888                for i in [2, 3, 4]:
889                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
890                        prepack_args[i] = (prepack_args[i],)
891        elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS:
892            prepack_op = get_qconv_prepack_op(func_node.target)  # type: ignore[arg-type]
893            # For conv_transpose1d, the stride, padding, and dilation args may be ints,
894            # in which case we need to convert them to tuples
895            if func_node.target == F.conv_transpose1d:
896                # Note prepack_args[5] is groups.
897                for i in [2, 3, 4, 6]:
898                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
899                        prepack_args[i] = (prepack_args[i],)
900            # swap dilation and groups
901            # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups}
902            # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation}
903            if len(prepack_args) > 6:
904                prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5]
905        else:
906            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
907        with model.graph.inserting_before(output_scale_node):  # type: ignore[arg-type]
908            # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack)
909            # They are not needed for compute op (i.e., quantized::linear)
910            kwargs = func_node.kwargs
911            # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias
912            if func_node.target == F.linear and "bias" in kwargs:
913                kwargs = kwargs.copy()
914                kwargs["B"] = kwargs["bias"]
915                del kwargs["bias"]
916            packed_weight = model.graph.create_node(
917                "call_function", prepack_op, tuple(prepack_args), kwargs
918            )
919
920        # Step 2: Replace reference pattern with the corresponding quantized op
921        (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target]  # type: ignore[index]
922        # conv_transpose does not support fusion with relu yet. q_relu_func is None in such cases
923        if q_relu_func is not None:
924            func_node.target = q_relu_func if relu_node is not None else q_func
925        else:
926            func_node.target = q_func
927        func_node.args = (
928            input_dq_node.args[0],
929            packed_weight,
930            output_scale_node,
931            output_zp_node,
932        )
933        # kwargs for func_node have been moved to kwargs for prepack op
934        func_node.kwargs = {}
935        q_node.replace_all_uses_with(func_node)
936        # Move func_node after output_zp_node in the graph
937        output_zp_node.append(func_node)
938
939        # Clean up: Remove quantize node, and the relu node if it exists
940        model.graph.erase_node(q_node)
941        if relu_node is not None and q_relu_func is not None:
942            model.graph.erase_node(relu_node)
943
944
945def _lower_dynamic_weighted_ref_functional(
946    model: GraphModule, qconfig_map: Dict[str, QConfigAny]
947):
948    """
949    Traverse the graph and replace functional reference patterns with their dynamically
950    quantized versions.
951    Examples:
952    quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic
953    to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16
954    """
955    modules = dict(model.named_modules(remove_duplicate=False))
956    nodes = list(model.graph.nodes)
957    # we want to search in reversed order so that we can match the larger patterns first
958    # e.g. we want to match linear - relu before linear.
959    for n in reversed(model.graph.nodes):
960        # Step 0: Find nodes that match this pattern
961        # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op)
962        # We search for the pattern backwards, starting with the quantize node
963        # Quantize node args: (func, scale, zp, dtype)
964        func_node = n
965        # Handle cases where the functional op is wrapped in a ReLU
966        if (
967            func_node.op == "call_function"
968            and func_node.target == F.relu
969            or func_node.op == "call_module"
970            and type(modules[str(func_node.target)]) == torch.nn.ReLU
971        ):
972            relu_node = func_node
973            func_node = relu_node.args[0]
974        else:
975            relu_node = None
976        if should_skip_lowering(func_node, qconfig_map):
977            continue
978        # Linear args: (dequantized inputs, dequantized weights[, bias])
979        # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups])
980        if (
981            func_node.op != "call_function"
982            or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP
983        ):
984            continue
985        (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
986        if (
987            input_dq_node.op != "call_method"
988            or input_dq_node.target != "dequantize"
989            or weight_dq_node.op != "call_method"
990            or weight_dq_node.target != "dequantize"
991        ):
992            continue
993
994        input_dynamic_q_node = input_dq_node.args[0]
995
996        if (
997            input_dynamic_q_node.op != "call_function"
998            or input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic
999        ):
1000            continue
1001
1002        reduce_range_node = None
1003        (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args
1004        is_fp16 = activation_dtype == torch.float16
1005        is_int8 = activation_dtype in [torch.quint8, torch.qint8]
1006        if not is_int8 and not is_fp16:
1007            continue
1008
1009        quantized_weight = weight_dq_node.args[0]
1010        weight_dtype = quantized_weight.args[-1]
1011
1012        # Step 1: Try to select the corresponding quantized op for the reference pattern
1013        dynamic_quant_dtype_key = (activation_dtype, weight_dtype)
1014        if (
1015            dynamic_quant_dtype_key
1016            not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]
1017        ):
1018            print(
1019                f"Didn't find dtype combination {dynamic_quant_dtype_key} during "
1020                f"dynamic quantized op lowering for {func_node.target}"
1021            )
1022            continue
1023        (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][
1024            dynamic_quant_dtype_key
1025        ]
1026
1027        if q_func is None or q_relu_func is None:
1028            print(
1029                "Didn't find corresponding quantized function or quantized relu function "
1030                f"for {func_node.target}, {dynamic_quant_dtype_key}"
1031            )
1032            continue
1033
1034        # Step 2: Replace quantized weights with packed weights, which will be folded later
1035        # Use the right prepack op and prepare the corresponding args
1036        # Linear prepack args: (quantized weights[, bias])
1037        # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
1038        prepack_args = [quantized_weight] + remaining_func_args
1039        prepack_kwargs = {}
1040        if func_node.target == F.linear:
1041            prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
1042            kwargs = func_node.kwargs.copy()
1043            if "bias" in kwargs:
1044                prepack_kwargs["B"] = kwargs["bias"]
1045                del kwargs["bias"]
1046                func_node.kwargs = kwargs
1047        elif func_node.target in CONV_FUNCTIONAL_OPS:
1048            prepack_op = get_qconv_prepack_op(func_node.target)
1049            # For conv1d, the stride, padding, and dilation args may be ints,
1050            # in which case we need to convert them to tuples
1051            if func_node.target == F.conv1d:
1052                for i in [2, 3, 4]:
1053                    if len(prepack_args) > i and isinstance(prepack_args[i], int):
1054                        prepack_args[i] = (prepack_args[i],)
1055        else:
1056            raise ValueError(f"Lowering is not supported for op '{func_node.target}'")
1057        with model.graph.inserting_before(func_node):
1058            packed_weight = model.graph.create_node(
1059                "call_function", prepack_op, tuple(prepack_args), prepack_kwargs
1060            )
1061
1062        # Step 3: Replace reference pattern with the corresponding quantized op
1063        func_node.target = q_relu_func if relu_node is not None else q_func
1064        if is_int8:
1065            func_node.args = (pattern_input, packed_weight, reduce_range_node)
1066        else:
1067            func_node.args = (pattern_input, packed_weight)
1068
1069        if relu_node is not None:
1070            relu_node.replace_all_uses_with(func_node)
1071
1072        # Step 4: Remove the relu node if it exists
1073        if relu_node is not None:
1074            model.graph.erase_node(relu_node)
1075
1076
1077def _lower_quantized_binary_op(model: GraphModule, qconfig_map: Dict[str, QConfigAny]):
1078    binary_ops_to_lower: List[Callable] = [
1079        operator.add,
1080        torch.add,
1081        operator.mul,
1082        torch.mul,
1083        torch.matmul,
1084    ]
1085    modules = dict(model.named_modules(remove_duplicate=False))
1086    for n in model.graph.nodes:
1087        # Step 0: Find nodes that match this pattern (dequantize - binary op - quantize)
1088        (q_node, relu_node, bop_node) = _match_static_pattern(
1089            n,
1090            modules,
1091            qconfig_map,
1092            binary_ops_to_lower,
1093            dequantize_node_arg_indices=[0, 1],
1094        )
1095        if q_node is None:
1096            continue
1097        assert bop_node is not None
1098        (_, scale_node, zero_point_node, _) = q_node.args
1099
1100        # Step 1: Remove dequant nodes
1101        num_dq_nodes = 0
1102        for arg in bop_node.args:
1103            if not is_dequantize_node(arg):
1104                continue
1105            dq_node = arg
1106            assert isinstance(dq_node, Node)
1107            dn_input = dq_node.args[0]
1108            bop_node.replace_input_with(dq_node, dn_input)  # type: ignore[arg-type]
1109            num_dq_nodes += 1
1110        assert num_dq_nodes > 0
1111
1112        # Step 2: Swap binary op to quantized binary op
1113        assert bop_node.target in QBIN_OP_MAPPING
1114        binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
1115        qbin_op = binop_to_qbinop[bop_node.target]
1116        # prepare the args for quantized binary op
1117        # (x, y)
1118        qop_node_args = list(bop_node.args)
1119        # (x, y, scale, zero_point)
1120        # add scale and zero_point arguments for Tensor - Tensor operation
1121        if num_dq_nodes == 2:
1122            qop_node_args.extend([scale_node, zero_point_node])
1123        # insert a call to quantized binary op and remove the original binary op
1124        with model.graph.inserting_after(q_node):
1125            qop_node = create_node_from_old_node_preserve_meta(
1126                model.graph,
1127                ("call_function", qbin_op, tuple(qop_node_args), {}),
1128                bop_node,
1129            )
1130            q_node.replace_all_uses_with(qop_node)
1131
1132        # Step 3: Remove quantize node, binary op node, and relu node if any
1133        model.graph.erase_node(q_node)
1134        if relu_node is not None:
1135            model.graph.erase_node(relu_node)
1136        model.graph.erase_node(bop_node)
1137
1138
1139def special_pattern_replacement(model: GraphModule):
1140    modules = dict(model.named_modules(remove_duplicate=False))
1141    for n in model.graph.nodes:
1142        q_node = n
1143        is_quantize = q_node.target == torch.quantize_per_tensor
1144        is_to_fp16 = (
1145            q_node.op == "call_method"
1146            and q_node.target == "to"
1147            and len(q_node.args) == 2
1148            and q_node.args[1] == torch.float16
1149        )
1150        if not (is_quantize or is_to_fp16):
1151            continue
1152        ref_node = q_node.args[0]
1153        # get output scale/zero_point/dtype from the quantize node
1154        # ref_node, scale_node, zero_point_node, dtype = q_node.args
1155        # TODO: add safety checks that the ref_node and dq_node each have exactly one user
1156        is_call_function, is_call_method, is_call_module = is_fixed_qparams_node(
1157            ref_node, modules
1158        )
1159        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
1160            # TODO: add a warning or error out here? (bc-breaking if error out)
1161            # warnings.warn(
1162            #     "Only reference patterns are currently supported for {dtype} dtype with {op} op"
1163            #     "".format(dtype=dtypes, op=ref_node))
1164            continue
1165
1166        is_call_function, is_call_method, is_call_module = is_default_node(
1167            ref_node, modules
1168        )
1169        if is_to_fp16 and (is_call_function or is_call_method or is_call_module):
1170            # TODO: add a warning or error out here? (bc-breaking if error out)
1171            continue
1172
1173        # This check includes all supported ops
1174        is_call_function, is_call_method, is_call_module = is_special_pattern_node(
1175            ref_node, modules
1176        )
1177        if not (is_call_module or is_call_function or is_call_method):
1178            continue
1179        assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
1180        dq_node_or_nodes = (
1181            ref_node.args[0]
1182            if len(ref_node.args) > 0
1183            else next(iter(ref_node.kwargs.values()))
1184        )
1185        assert isinstance(dq_node_or_nodes, (Node, tuple, list))
1186        is_dequantize = False
1187        if isinstance(dq_node_or_nodes, Node):
1188            is_dequantize = (
1189                dq_node_or_nodes.op == "call_method"
1190                and dq_node_or_nodes.target == "dequantize"
1191            )
1192        elif isinstance(dq_node_or_nodes, (tuple, list)):
1193            is_dequantize = all(
1194                x.op == "call_method" and x.target == "dequantize"
1195                for x in dq_node_or_nodes
1196            )
1197
1198        if not is_dequantize:
1199            continue
1200
1201        # TODO: enable when we have patterns that need to swap the modules
1202        if is_call_module:
1203            ref_module = modules[ref_node.target]
1204            if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize:
1205                qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module))
1206                scale_node = q_node.args[1]
1207                zero_point_node = q_node.args[2]
1208                output_scale = getattr(model, scale_node.target)
1209                output_zero_point = getattr(model, zero_point_node.target)
1210
1211                qmodule = qmodule_cls.from_reference(  # type:ignore[union-attr]
1212                    ref_module, output_scale, output_zero_point
1213                )
1214                # replace reference module with quantized module
1215                parent_name, module_name = _parent_name(ref_node.target)
1216                setattr(modules[parent_name], module_name, qmodule)
1217
1218        # reroute around dq node:
1219        dq_nodes: List[Node] = []
1220        if isinstance(dq_node_or_nodes, Node):
1221            dq_nodes = [dq_node_or_nodes]
1222        elif isinstance(dq_node_or_nodes, (tuple, list)):
1223            dq_nodes = list(dq_node_or_nodes)
1224
1225        for dq_node in dq_nodes:
1226            dn_input = dq_node.args[0]
1227            ref_node.replace_input_with(dq_node, dn_input)
1228
1229        # store q node args
1230        qnode_qparams = list(q_node.args)[1:]
1231        # replace uses of q node with input and remove q node
1232        q_node_input = q_node.args[0]
1233        q_node.replace_all_uses_with(q_node_input)
1234        model.graph.erase_node(q_node)
1235
1236        is_call_function, is_call_method, is_call_module = is_default_node(
1237            ref_node, modules
1238        )
1239        if is_call_function:
1240            # pass scale/zero_point arguments from quantize_per_tensor to the default node operator
1241            # insert an op after the zero_point node so that the scale/zero_point
1242            # nodes are available
1243            qop = get_quantized_operator(ref_node.target)
1244            args = list(ref_node.args)
1245            kwargs = dict(ref_node.kwargs)
1246            if qop in QOP_TO_ARG_NAMES_TO_SKIP:
1247                args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop]
1248                for arg in args_to_skip:
1249                    if arg in kwargs:
1250                        kwargs.pop(arg)
1251            kwargs["output_scale"] = qnode_qparams[0]
1252            kwargs["output_zero_point"] = qnode_qparams[1]
1253            with model.graph.inserting_after(qnode_qparams[1]):
1254                qop_node = create_node_from_old_node_preserve_meta(
1255                    model.graph, ("call_function", qop, tuple(args), kwargs), ref_node
1256                )
1257                ref_node.replace_all_uses_with(qop_node)
1258                model.graph.erase_node(ref_node)
1259        else:
1260            # remove scale/zero_point node for quantize node
1261            for n in qnode_qparams:
1262                if isinstance(n, Node):
1263                    model.graph.erase_node(n)
1264
1265    return model
1266
1267
1268def _lower_getattr_tensor_metadta_op(model: GraphModule):
1269    """Modified the graph of the model inplace, to skip extra dequantize op before
1270    the general tensor shape ops when possible
1271    """
1272    for n in model.graph.nodes:
1273        if is_getattr_tensor_metadata_node(n):
1274            maybe_dq = n.args[0]
1275            if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
1276                continue
1277            # skip the dequantize node
1278            args = list(n.args)
1279            args[0] = n.args[0].args[0]
1280            n.args = tuple(args)
1281
1282
1283def _lower_get_tensor_info_op(model: GraphModule):
1284    """Modified the graph of the model inplace, to skip extra dequantize op before
1285    the general tensor shape ops when possible
1286    """
1287    for n in model.graph.nodes:
1288        if not is_get_tensor_info_node(n):
1289            continue
1290        maybe_dq = n.args[0]
1291        if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
1292            continue
1293        # skip the dequantize node
1294        args = list(n.args)
1295        args[0] = n.args[0].args[0]
1296        n.args = tuple(args)
1297
1298
1299def _lower_to_native_backend(
1300    model: GraphModule,
1301    qconfig_map: Dict[str, QConfigAny],
1302    node_name_to_scope: Dict[str, Tuple[str, type]],
1303) -> GraphModule:
1304    """Lower a quantized reference model (with reference quantized operator patterns)
1305    to the native backend in PyTorch (fbgemm/qnnpack); both backends share the same
1306    operator signatures, so they can be lowered with the same function
1307    """
1308    _lower_static_weighted_ref_module(model, qconfig_map)
1309    _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map)
1310    _lower_dynamic_weighted_ref_module(model)
1311    _lower_weight_only_weighted_ref_module(model)
1312    _lower_static_weighted_ref_functional(model, qconfig_map)
1313    _lower_dynamic_weighted_ref_functional(model, qconfig_map)
1314    _lower_quantized_binary_op(model, qconfig_map)
1315    _lower_getattr_tensor_metadta_op(model)
1316    _lower_get_tensor_info_op(model)
1317    special_pattern_replacement(model)
1318    model.graph.eliminate_dead_code()
1319    model = fold_weight(model, node_name_to_scope)
1320    model.graph.eliminate_dead_code()
1321    model.recompile()
1322    model.graph.lint()
1323    return model
1324
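# A minimal end-to-end usage sketch (hypothetical toy model and inputs) showing
# where this pass sits in the FX workflow: convert_fx produces a reference model
# and then runs _lower_to_native_backend internally to emit fbgemm/qnnpack ops.
def _example_end_to_end_lowering() -> GraphModule:
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    float_model = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 4),)
    prepared = prepare_fx(float_model, get_default_qconfig_mapping(), example_inputs)
    prepared(*example_inputs)  # calibration
    return convert_fx(prepared)  # lowering, including the passes in this file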