# mypy: ignore-errors

import copy
import operator
import warnings
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY
from torch.ao.quantization.backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from torch.ao.quantization.backend_config.utils import (
    get_fused_module_classes,
    get_pattern_to_dtype_configs,
    get_qat_module_classes,
    get_root_module_to_quantized_reference_module,
)
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quant_type import QuantType
from torch.ao.quantization.quantize import _remove_qconfig
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
    _parent_name,
    activation_is_statically_quantized,
    get_qparam_dict,
    get_swapped_custom_module_class,
    is_per_channel,
    to_underlying_dtype,
    weight_is_quantized,
)
from torch.fx import GraphModule
from torch.fx.graph import Argument, Graph, Node
from torch.nn.utils.parametrize import type_before_parametrizations

# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib  # noqa: F401
from ._equalize import convert_eq_obs, update_obs_for_equalization
from .custom_config import ConvertCustomConfig, PrepareCustomConfig
from .graph_module import _is_observed_module, _is_observed_standalone_module
from .lower_to_fbgemm import lower_to_fbgemm
from .qconfig_mapping_utils import (
    _compare_prepare_convert_qconfig_mappings,
    _generate_node_name_to_qconfig,
    _is_qconfig_supported_by_dtype_configs,
    _update_qconfig_for_fusion,
    _update_qconfig_for_qat,
)
from .utils import (
    _get_module,
    _is_custom_module_lstm,
    _is_custom_module_mha,
    assert_and_get_unique_device,
    collect_producer_nodes,
    create_getattr_from_value,
    get_custom_module_class_keys,
    graph_module_from_producer_nodes,
    node_arg_is_weight,
)


__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
]

SUPPORTED_QDTYPES = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.float8_e5m2,
    torch.float8_e4m3fn,
]

_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
}
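# Illustrative call shape for the choose_qparams ops mapped above (the per-tensor
# affine variant; `x`, `quant_min`, `quant_max` and `eps` are placeholders for the
# observed input node and the parameters read off the activation_post_process
# module):
#   scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
#       x, quant_min, quant_max, eps, torch.int8)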


def _replace_observer_with_quantize_dequantize_node_decomposed(
    model: torch.fx.GraphModule,
    node: Node,
    modules: Dict[str, torch.nn.Module],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    """Replace an activation_post_process module call node with quantize and
    dequantize nodes that operate on the decomposed Tensor representation

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
    torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...

    or quantize_per_channel and dequantize_per_channel
    """
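    # Illustrative shape of the rewrite for per-tensor static quantization
    # (`x`, `scale`, `zero_point` etc. are placeholders, not names used by this
    # pass):
    #   q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    #       x, scale, zero_point, quant_min, quant_max, torch.int8)
    #   dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    #       q, scale, zero_point, quant_min, quant_max, torch.int8)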
    graph = model.graph
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = _get_module_path_and_prefix(
        node, node_name_to_scope, node_name_to_qconfig
    )
    activation_post_process = modules[node.target]
    if hasattr(activation_post_process, "convert"):
        activation_post_process.convert(model, node)
        return
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig)
        for n in list(node.args) + list(node.users.keys())
    )
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node

    # 1. extract the information from activation_post_process module for generating
    # the quantize and dequantize operator
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]

    def add_dequantize_op_kwargs(dequantize_op, input_node):
        dequantize_op_kwargs = {}
        if "val" in input_node.meta:
            dq_out_dtype = input_node.meta["val"].dtype
            if dq_out_dtype != torch.float32:
                dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
        return dequantize_op_kwargs

    if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract information for inserting q/dq node from activation_post_process
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
            dequantize_op = (
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            )
            quant_min = activation_post_process.quant_min
            quant_max = activation_post_process.quant_max
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_,
            }
        else:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
            scale = float(scale)
            zero_point = int(zero_point)
            quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
            quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_,
            }

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ["_scale_", "_zero_point_"] and (
                    not isinstance(value_or_node, (float, int))
                ):
                    # For scale and zero_point values we register them as buffers in the root module.
                    # However, note that when the values are not tensors, as in the case of
                    # per_tensor quantization, they will be treated as literals.
                    # However, registering them as a node seems to cause issue with dynamo
                    # tracing where it may consider tensor overload as opposed to default.
                    # With extra check of scale and zero_point being scalar, it makes
                    # sure that the default overload can be used.
                    # TODO: maybe need more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node
                    )
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            # use the same qparams from quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                add_dequantize_op_kwargs(dequantize_op, input_node),
            )

            node.replace_all_uses_with(dequantized_node)
            # propagate numeric debug handle from observer/fake_quant node to dequantize node
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                if CUSTOM_KEY not in dequantized_node.meta:
                    dequantized_node.meta[CUSTOM_KEY] = {}
                dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                    CUSTOM_KEY
                ][NUMERIC_DEBUG_HANDLE_KEY]
            graph.erase_node(node)
    elif is_dynamic:
        # uint8/int8/fp16 dynamic quantization

        # 1. extract information for inserting q/dq node from activation_post_process
        node_type = "call_function"
        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        # we only use choose_qparams for is_decomposed now,
        # but we should probably align the non-decomposed path with this as well,
        # and that can be done after we remove reduce_range flag
        # 1. extract qparams from activation_post_process module
        dtype_ = to_underlying_dtype(dtype)
        assert dtype_ in [torch.uint8, torch.int8], (
            "only uint8 and int8 are supported in reference flow for "
            "dynamic quantization right now"
        )
        quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
        quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
        qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine)  # type: ignore[attr-defined]
        eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps)  # type: ignore[attr-defined]
        # note: scale and zero_point are missing for quantize_per_tensor op
        # we'll need to get this from choose_qparams op, which we'll add after
        # this step
        qparams = {
            "_quant_min_": quant_min,
            "_quant_max_": quant_max,
            "_eps_": eps,
            "_dtype_": dtype_,
        }

        choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
        # 2. insert choose_qparams op and update the qparams list
        with graph.inserting_before(node):
            input_node = node.args[0]
            choose_qparams_op_inputs = [node.args[0]]
            for key, value in qparams.items():
                # we have quant_min, quant_max and dtype, all should be stored
                # as literals
                choose_qparams_op_inputs.append(value)
            choose_qparams_node = graph.create_node(
                "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {}
            )
            # choose_qparams returns (scale, zero_point)
            scale_node = graph.create_node(
                "call_function", operator.getitem, (choose_qparams_node, 0), {}
            )
            zero_point_node = graph.create_node(
                "call_function", operator.getitem, (choose_qparams_node, 1), {}
            )
            quant_min = qparams["_quant_min_"]
            quant_max = qparams["_quant_max_"]
            dtype = qparams["_dtype_"]
            qparams = {
                "_scale_": scale_node,
                "_zero_point_": zero_point_node,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype,
            }

        # 3. replace activation_post_process node with quantize and dequantize nodes
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ["_scale_", "_zero_point_"]:
                    # in this case we have a node in the graph since it's dynamically
                    # computed from the input, with choose_qparams op
                    qparam_node = value_or_node
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we
                    # store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            # use the same qparams from quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            # need to use the tensor variant of this op, since scale and zero_point
            # from choose_qparam are Tensors, instead of float/int, this is to
            # prevent these nodes being traced away by downstream systems
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                add_dequantize_op_kwargs(dequantize_op, input_node),
            )

            def remap_fn(x):
                return dequantized_node if x is node else x

            node.replace_all_uses_with(dequantized_node)
            # propagate numeric debug handle from observer/fake_quant node to dequantize node
            if NUMERIC_DEBUG_HANDLE_KEY in node.meta:
                dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
                    NUMERIC_DEBUG_HANDLE_KEY
                ]
            graph.erase_node(node)
    elif dtype == torch.float16:
        raise NotImplementedError("decomposed to float16 op not implemented yet")

    # should not reach since we have checks in the beginning to make sure the
    # activation_post_process is supported


def _replace_observer_with_quantize_dequantize_node(
    model: torch.fx.GraphModule,
    node: Node,
    modules: Dict[str, torch.nn.Module],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    """Replace an activation_post_process module call node with quantize and
    dequantize nodes

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
    """
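    # Illustrative shape of the rewrite for per-tensor static quantization
    # (`x`, `scale`, `zero_point` are placeholders for the actual graph values):
    #   q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    #   dq = q.dequantize()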
    assert modules is not None
    assert isinstance(node.target, str)
    graph = model.graph
    module_path, prefix = _get_module_path_and_prefix(
        node, node_name_to_scope, node_name_to_qconfig
    )
    activation_post_process = modules[node.target]
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(
        _has_none_qconfig(n, node_name_to_qconfig)
        for n in list(node.args) + list(node.users.keys())
    )
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    if dtype in [
        torch.quint8,
        torch.qint8,
        torch.qint32,
        torch.float8_e5m2,
        torch.float8_e4m3fn,
    ] and (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract the information from activation_post_process module for generating
        # the quantize and dequantize operator
        node_type = "call_function"
        quantize_op: Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_dtype_": dtype,
            }
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ["_scale_", "_zero_point_"]:
                    # For scale and zero_point values we register them as buffers in the root module.
                    # TODO: maybe need more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node
                    )
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:
        # uint8/int8/fp16 dynamic quantization branch

        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}

        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        node_type = "call_method"
        quantize_op = "to"  # type: ignore[assignment]
        qparams = {"_dtype_": dtype}
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(
                node_type, quantize_op, tuple(quantize_op_inputs), {}
            )
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)

    # should not reach since we have checks in the beginning to make sure the
    # activation_post_process is supported


# this is a temporary hack for custom module, we may want to implement
# this properly after the custom module class design is finalized
# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
def _replace_observer_or_dequant_stub_with_dequantize_node(
    node: Node, graph: Graph
) -> None:
    call_custom_module_node = node.args[0]
    assert isinstance(
        call_custom_module_node, Node
    ), f"Expecting the custom module call node to be a Node, but got {call_custom_module_node}"
    node.replace_all_uses_with(call_custom_module_node)
    graph.erase_node(node)
    _insert_dequantize_node(call_custom_module_node, graph)


def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    return (
        (dtype in SUPPORTED_QDTYPES and (not is_dynamic))
        or is_dynamic  # type: ignore[return-value]
        or dtype == torch.float16
    )


def _has_none_qconfig(
    node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]
) -> bool:
    """Check if a node has a qconfig of None, i.e. user requested to not quantize
    the node
    """
    return (
        isinstance(node, Node)
        and node.name in node_name_to_qconfig
        and node_name_to_qconfig[node.name] is None
    )


def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
    """Extract the subgraph that produces the weight for dynamic quant
    or weight only quant node and run the subgraph to observe the weight.
    Note that the observers of dynamic quant or weight only quant ops are
    run during the convert step.
    """
    for node in observed.graph.nodes:
        if node.op != "call_function":
            continue
        for node_arg in node.args:
            # node_arg is weight
            if node_arg and node_arg_is_weight(node, node_arg):
                weight_observer_nodes = collect_producer_nodes(node_arg)
                if weight_observer_nodes is None:
                    continue
                weight_observer_module = graph_module_from_producer_nodes(
                    observed, weight_observer_nodes
                )
                # run the weight observer
                weight_observer_module()


def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
    """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
    we'll recursively remove the dequantize Node
    """
    if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize":
        quantize_node = arg.args[0]
        # we only replace the specific use since dequantize could be used by other nodes
        # as well
        node.replace_input_with(arg, quantize_node)
    elif isinstance(arg, (list, tuple)):
        for arg_element in arg:
            _maybe_recursive_remove_dequantize(arg_element, node, graph)
    elif isinstance(arg, dict):
        for arg_element in arg.values():
            _maybe_recursive_remove_dequantize(arg_element, node, graph)
    else:
        warnings.warn(
            f"Unsupported node type in recursive remove dequantize: {type(arg)}"
        )


def _get_module_path_and_prefix(
    obs_node: Node,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> Tuple[str, str]:
    """Given an observer node, get the `Scope` or the fully qualified name for
    the submodule containing the observed node. Also return a prefix of "_input"
    when the observed node is an input of an F.linear op and not the output of another
    quantized op.
    TODO: this logic is hacky, we should think about how to remove it or make it more
    general
    """
    observed_node = obs_node.args[0]
    # an observer can be inserted for both input of the next operator or output of the previous
    # operator (they can be the same)
    # this flag identifies if the observer is inserted only because the observed node is
    # the input of the next operator
    assert isinstance(
        observed_node, Node
    ), f"Expecting observed node to be a Node, but got {observed_node}"
    is_input_observer_only = (
        node_name_to_qconfig[observed_node.name] is None
        if observed_node.name in node_name_to_qconfig
        else None
    )
    if is_input_observer_only:
        # if the quantize function is at the input of op, then we find the first user of the observer_node
        # to get the path. If a linear call_function is in the user list, we return the first instance
        # of linear node to get the FQN.
        users = list(obs_node.users)
        first_linear_use_or_first_use = users[0] if users else None
        linear_node = None
        for n in users:
            if n.op == "call_function" and n.target == torch.nn.functional.linear:
                linear_node = n
                break
        if linear_node:
            first_linear_use_or_first_use = linear_node
        prefix = "_input"
    else:
        # if the quantize function is at the output of the op, we use the observer input node to get the path
        first_linear_use_or_first_use = observed_node
        prefix = ""

    if (
        first_linear_use_or_first_use
        and first_linear_use_or_first_use.name in node_name_to_scope
    ):
        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
    else:
        # TODO: it's not used, so actually we can skip quantization
        # but this requires changing return type of quantize_node
        # we can fix it later if needed
        module_path = ""
    return module_path, prefix


def _insert_dequantize_node(node: Node, graph: Graph) -> None:
    """Inserts dequantize node for `node` in `graph`"""
    with graph.inserting_after(node):
        dequantize_node = graph.call_method("dequantize", (node,))
        for user_node in dict(node.users):
            if user_node is not dequantize_node:
                user_node.replace_input_with(node, dequantize_node)


def _maybe_get_observer_for_node(
    node: Node, modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """
    If the node is observed, return the observer
    instance. Otherwise, return None.
    """
    for maybe_obs_node in node.users.keys():
        if maybe_obs_node.op == "call_module":
            maybe_obs = modules[str(maybe_obs_node.target)]
            if _is_activation_post_process(maybe_obs):
                return maybe_obs
    return None


def convert_standalone_module(
    node: Node,
    modules: Dict[str, torch.nn.Module],
    model: torch.fx.GraphModule,
    is_reference: bool,
    backend_config: Optional[BackendConfig],
) -> None:
    """Converts an observed standalone module to a quantized standalone module by calling
    the fx convert api. Currently this uses the same `is_reference` flag as the parent, but we
    may change this behavior in the future (e.g. separating quantization and lowering for
    standalone modules as well)

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - model: original model
      - is_reference: a flag from parent provided by user to decide if we want to
        produce a reference model or a fbgemm/qnnpack model
      - backend_config: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = observed_standalone_module.meta[
        "_observed_graph_module_attrs"
    ].standalone_module_input_quantized_idxs
    # remove the dequantize nodes for inputs
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = observed_standalone_module.meta[
        "_observed_graph_module_attrs"
    ].standalone_module_output_quantized_idxs
    if len(sm_output_quantized_idxs) > 0:
        assert (
            sm_output_quantized_idxs[0] == 0
        ), "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        _insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module, backend_config=backend_config
    )
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module


def convert_weighted_module(
    node: Node,
    modules: Dict[str, torch.nn.Module],
    observed_node_names: Set[str],
    node_name_to_qconfig: Dict[str, QConfigAny],
    backend_config: BackendConfig,
    is_decomposed: bool = False,
    is_reference: bool = False,
) -> None:
    """Convert a weighted module to a reference quantized module in the model.
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
      - node: The call_module node of the observed weighted module
      - modules: named_module of original model
      - observed_node_names: names of the observed fx nodes; we can skip
        this conversion if the node is not observed
    """
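    # Illustrative effect (module names are placeholders): an observed
    # torch.nn.Linear (or its QAT counterpart) handled here is replaced by a
    # reference module such as torch.ao.nn.quantized.reference.Linear,
    # constructed via
    #   ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)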
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(original_module, qat_module_classes):
        # Converting qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if (
        qconfig is None
        or _has_none_qconfig(node, node_name_to_qconfig)
        or not is_observed
    ):
        return

    # skip converting to reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update(
            {
                "weight_ih": weight_qparams_ih,
                "weight_hh": weight_qparams_hh,
            }
        )
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
            device = assert_and_get_unique_device(float_module)
            if device:
                weight_post_process.to(device)

        # Call weight observer/fake_quant at least once to ensure the scales and zero points
        # have the right shapes. Note: there are two cases where we don't have to do this:
        #
        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
        #     and this typically happens during training, so we don't need to do it here.
        #
        # (2) Non-reference (lowered) case: The quantized module's from_float method already
        #     calls the weight observer/fake_quant, so we don't have to do it here.
        #
        # Currently we ignore both cases and call the weight observer/fake_quant here
        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
        # in test code, which may not always train before convert. In the future, we should
        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
        #
        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
        # Note that we still need it for PTQ in the PT2 flow since the model's forward
        # method doesn't call the weight observer.
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)  # type: ignore[operator]

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(
        type_before_parametrizations(float_module), None
    )
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)


def _remove_previous_dequantize_in_custom_module(
    node: Node, prev_node: Node, graph: Graph
) -> None:
    """
    Given a custom module `node`, if the previous node is a dequantize, reroute the custom module as follows:

    Before: quantize - dequantize - custom_module
    After: quantize - custom_module
                 \\ - dequantize
    """
    # expecting the input node for a custom module node to be a Node
    assert isinstance(
        prev_node, Node
    ), f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
    if prev_node.op == "call_method" and prev_node.target == "dequantize":
        node.replace_input_with(prev_node, prev_node.args[0])
        # Remove the dequantize node if it doesn't have other users
        if len(prev_node.users) == 0:
            graph.erase_node(prev_node)


def convert_custom_module(
    node: Node,
    graph: Graph,
    modules: Dict[str, torch.nn.Module],
    custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
    statically_quantized_custom_module_nodes: Set[Node],
) -> None:
    """Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.
    For static quantization, we also remove the previous `dequantize` node and
    attach the output observer node to the module; that observer will later be
    converted to a dequantize node instead of a quantize-dequantize pair
    in the graph. In the end we have a quantized custom module that
    has the same interface as a default quantized module in the nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we add the custom module node to
        this set if we find it is statically quantized; this is used later when
        converting observers to quant/dequant node pairs. If the observed node is a
        statically quantized custom module node, we convert its observer to a
        dequantize node, which keeps the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with the reference model design
        as well, but there have been some discussions around the interface, so we can do
        it later.
    """
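    # Illustrative shape of `custom_module_class_mapping` (the observed/quantized
    # classes named here are hypothetical):
    #   {
    #       QuantType.STATIC: {ObservedCustomLSTM: QuantizedCustomLSTM},
    #       QuantType.DYNAMIC: {},
    #   }
    # get_swapped_custom_module_class() looks up the quantized class for the
    # observed module's type under the quant type derived from `qconfig`.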
    observed_custom_module = modules[str(node.target)]
    maybe_obs = _maybe_get_observer_for_node(node, modules)
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # The inputs are tuples in the form (input, (hidden0, hidden1))
            # Ensure all three input nodes are quantized
            assert (
                len(node.args) == 2
                and isinstance(node.args[1], tuple)
                and len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # Inputs are in the form (query, key, value)
            # TODO: This is the first step in enabling the full fx custom module
            # quantization path for MultiheadAttention, and only covers the inputs
            # to the module.
            # Additional handling is yet to be implemented for the outputs, similar
            # to LSTM custom module
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)
            # absorb the following observer into the module conversion
            activation_post_process = _maybe_get_observer_for_node(node, modules)
            assert activation_post_process is not None
            observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig
    )
    quantized_custom_module = quantized_custom_module_class.from_observed(
        observed_custom_module
    )
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)


def convert(
    model: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig_flag: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> GraphModule:
    """
    We convert an observed model (a module with observer calls) to a reference
    quantized model. The rule is simple:
    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to reference
       quantized modules; this requires us to know whether the dtype configured for the
       weight is supported in the backend, which is determined in the prepare step and
       stored in observed_node_names, so we can decide whether to swap the
       module based on this set

    Args:
       * `is_standalone_module`: when this flag is True, it means we are quantizing
       a submodule that is not inlined in the parent module, and it will be quantized
       separately as one unit.

       * `is_decomposed`: a boolean flag to indicate whether we want to use the
        quantize operator for decomposed quantized tensors
        (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
        quantized tensors (torch.quantize_per_tensor)

    Returns:
         a quantized standalone module; whether input/output is quantized is
         specified by prepare_custom_config, with
         input_quantized_idxs and output_quantized_idxs, please
         see docs for :func:`~torch.ao.quantization.prepare_fx` for details
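
    Example (a minimal sketch; ``float_model`` and ``example_inputs`` are
    placeholder names, and the public entry points
    :func:`~torch.ao.quantization.quantize_fx.prepare_fx` /
    :func:`~torch.ao.quantization.quantize_fx.convert_fx` call into this
    function internally)::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

        float_model.eval()
        qconfig_mapping = get_default_qconfig_mapping("x86")
        prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)
        prepared(*example_inputs)         # calibrate with representative data
        quantized = convert_fx(prepared)  # graph rewriting happens in convert()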
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        qconfig_mapping = (
            QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
        )
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
    node_name_to_scope: Dict[
        str, Tuple[str, type]
    ] = observed_graph_module_attrs.node_name_to_scope
    prepare_custom_config: PrepareCustomConfig = (
        observed_graph_module_attrs.prepare_custom_config
    )
    observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names
    node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig  # type: ignore[assignment]

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)

        if observed_graph_module_attrs.is_qat:
            _update_qconfig_for_qat(qconfig_mapping, backend_config)
        _update_qconfig_for_fusion(model, qconfig_mapping)

        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope
        )
        # check the convert_node_name_to_qconfig generated and ensure that
        # all the values either match what was set in prepare node_name_to_qconfig
        # or are set to None in the convert_node_name_to_qconfig.
        for k, v in node_name_to_qconfig.items():
            assert (
                k in convert_node_name_to_qconfig
            ), f"Expected key {k} in convert node_name_to_qconfig"
            if convert_node_name_to_qconfig[k] is not None:
                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
                )
        node_name_to_qconfig = convert_node_name_to_qconfig

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config.observed_to_quantized_mapping
    )
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    _run_weight_observers(model, backend_config)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == "placeholder":
            graph_inputs.append(node.name)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    # convert to a tuple of classes so that it can be used with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == "placeholder":
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # We need to dequantize the inputs since all operators take
                # floating point inputs in reference quantized models
                _insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Results are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the output node at the end, if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(
                        output[idx], return_node, model.graph
                    )
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(
                    f"Unsupported node type for output_quantized_idxs: {type(output)}"
                )
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            if _is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    _replace_observer_or_dequant_stub_with_dequantize_node(
                        node, model.graph
                    )
                else:
                    if is_decomposed:
                        _replace_observer_with_quantize_dequantize_node_decomposed(
                            model,
                            node,
                            modules,
                            node_name_to_scope,
                            node_name_to_qconfig,
                        )
                    else:
                        _replace_observer_with_quantize_dequantize_node(
                            model,
                            node,
                            modules,
                            node_name_to_scope,
                            node_name_to_qconfig,
                        )
            elif isinstance(mod, DeQuantStub):
                _replace_observer_or_dequant_stub_with_dequantize_node(
                    node, model.graph
                )
            elif _is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config
                )
            # below this point `type_before_parametrizations` is used
            # instead of `type` to handle situations with fx quant + sparsity
            elif type_before_parametrizations(mod) in set(root_module_classes).union(
                qat_module_classes
            ).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if (
                    type_before_parametrizations(mod) in fused_module_classes
                    and type_before_parametrizations(mod[0]) not in root_module_classes
                ):  # type: ignore[index]
                    continue
                convert_weighted_module(
                    node,
                    modules,
                    observed_node_names,
                    node_name_to_qconfig,
                    backend_config,
                    is_decomposed,
                    is_reference,
                )
            elif type_before_parametrizations(mod) in custom_module_classes:
                convert_custom_module(
                    node,
                    model.graph,
                    modules,
                    custom_module_class_mapping,
                    statically_quantized_custom_module_nodes,
                )

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)

    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model