# torch/ao/quantization/fx/prepare.py
# mypy: allow-untyped-defs
import copy
import warnings
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization import (
    _DerivedObserverOrFakeQuantize,
    FixedQParamsFakeQuantize,
    FixedQParamsObserver,
    ObserverBase,
    ObserverOrFakeQuantize,
    PlaceholderObserver,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeConfig,
    get_native_backend_config,
)
from torch.ao.quantization.backend_config.utils import (
    get_fusion_pattern_to_root_node_getter,
    get_module_to_qat_module,
    get_pattern_to_dtype_configs,
)
from torch.ao.quantization.observer import _is_activation_post_process, _PartialWrapper
from torch.ao.quantization.qconfig import _is_reuse_input_qconfig, QConfigAny
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quantize import convert, propagate_qconfig_
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    EdgeOrNode,
    FixedQParamsQuantizationSpec,
    QuantizationSpec,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.ao.quantization.utils import (
    _parent_name,
    get_qconfig_dtypes,
    get_swapped_custom_module_class,
    NodePattern,
    Pattern,
)
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node
from torch.fx.node import Argument

from ._equalize import is_equalization_observer, node_supports_equalization
from .custom_config import PrepareCustomConfig, StandaloneModuleConfigEntry
from .match_utils import _find_matches, _MatchResultWithQConfig
from .pattern_utils import _sorted_patterns_dict
from .qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
    _get_flattened_qconfig_dict,
    _update_qconfig_for_fusion,
    _update_qconfig_for_qat,
)
from .quantize_handler import (
    _default_root_node_getter,
    _get_pattern_to_quantize_handlers,
    QuantizeHandler,
)
from .utils import (
    _insert_dequant_stubs_for_custom_module_lstm_output,
    _is_custom_module_lstm,
    _maybe_get_custom_module_lstm_from_node_arg,
    _qconfig_satisfies_dtype_config_constraints,
    all_node_args_have_no_tensors,
    assert_and_get_unique_device,
    get_custom_module_class_keys,
    get_new_attr_name_with_prefix,
    get_non_observable_arg_indexes_and_types,
    node_arg_is_bias,
    node_arg_is_weight,
    NON_QUANTIZABLE_WEIGHT_OPS,
    ObservedGraphModuleAttrs,
)


__all__ = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
]


# list of dtypes to not add observers to
_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
_OBS_DTYPE_LIST = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.float16,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.float8_e5m2,
    torch.float8_e4m3fn,
]

_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

# note: the following default target dtype info dicts are temporary,
# should be moved to the new programmable API class soon
_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
}

_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
}


def _get_observer_kwargs(
    quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec]
):
    kwargs_dict = asdict(quant_spec)
    return copy.deepcopy(kwargs_dict)


def _get_qspec_for_arg(
    arg: Node,
    input_qspec_map: Dict[Node, QuantizationSpecBase],
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[QuantizationSpecBase]:
    while _is_activation_post_process_node(arg, named_modules):
        arg = arg.args[0]  # type: ignore[assignment]
    return input_qspec_map.get(arg, None)


def _create_obs_or_fq_from_qspec(
    quantization_spec: Optional[QuantizationSpecBase],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
141    """Create observer or fake quantize objects based on quantization spec
142
143    Args:
144       quantization_spec: used to store parameters to create the observer or fake quantizer
145       obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant
146       instance, it may be reused for different edge/output depending on configuration
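
    Example (an illustrative sketch; ``conv_node`` is a hypothetical node that
    already has an observer/fake_quant recorded in ``obs_or_fq_map``)::

        spec = SharedQuantizationSpec(conv_node)
        obs = _create_obs_or_fq_from_qspec(spec, obs_or_fq_map, is_qat=False)
        assert obs is obs_or_fq_map[conv_node]  # the existing instance is reused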
147    """
148    if quantization_spec is None:
149        return None
150    if isinstance(quantization_spec, SharedQuantizationSpec):
151        edge_or_node = quantization_spec.edge_or_node
152        assert edge_or_node in obs_or_fq_map, (
153            "please make sure only refer to edge or node that has "
154            f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
155        )
156        return obs_or_fq_map[edge_or_node]
157    elif isinstance(quantization_spec, DerivedQuantizationSpec):
158        # can't use asdict, so not calling get_observer_kwargs here
159        kwargs = {
160            "dtype": quantization_spec.dtype,
161            "derive_qparams_fn": quantization_spec.derive_qparams_fn,
162            "quant_min": quantization_spec.quant_min,
163            "quant_max": quantization_spec.quant_max,
164            "qscheme": quantization_spec.qscheme,
165            "ch_axis": quantization_spec.ch_axis,
166        }
167        edge_or_nodes = quantization_spec.derived_from
168        obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
169        kwargs["obs_or_fqs"] = obs_or_fqs
170        return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
171    elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
172        kwargs = _get_observer_kwargs(quantization_spec)
173        observer_ctr = FixedQParamsObserver.with_args(**kwargs)
174        if is_qat:
175            return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)()
176        else:
177            return observer_ctr()
178
179    assert isinstance(quantization_spec, QuantizationSpec)
180    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
181    kwargs = _get_observer_kwargs(quantization_spec)
182    kwargs.pop("observer_or_fake_quant_ctr")
183    # we will remove is_dynamic from QuantizationSpec because
184    # it seems that dynamic range quantization
185    obs_or_fq_class = observer_or_fake_quant_ctr
186    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
187        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
188    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
189        kwargs.pop("ch_axis")
190    return observer_or_fake_quant_ctr.with_args(**kwargs)()
191
192
193def _needs_obs_or_fq(
194    prev_output_dtype: Any,
195    prev_output_is_dynamic: bool,
196    cur_target_dtype: Any,
197    cur_target_is_dynamic: bool,
198    reuse_input_obs_or_fq: bool,
199    is_zeroth_arg: bool = False,
200) -> bool:
201    """
202    note: we will treat "not specified" as torch.float for now
203    utility function that checks if we should insert an observer or fake quant node
204    base on the requested dtype for the nodes from user
205
206    is_zeroth_arg: we only dynamically quantize the first arg of the node right now
207      this should be removed when we enable configuring dynamic quantization
208      for a specific argument, this can be removed if we deprecate fx graph mode
209      quantization
210
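
    Example (illustrative; these calls sketch the decision logic)::

        # fp32 producer -> statically quantized quint8 consumer: observe
        _needs_obs_or_fq(torch.float, False, torch.quint8, False, False)  # True
        # dtypes already match: no observer needed
        _needs_obs_or_fq(torch.quint8, False, torch.quint8, False, False)  # False
        # dynamically quantized consumer: only the zeroth arg is observed
        _needs_obs_or_fq(torch.float, False, torch.quint8, True, False,
                         is_zeroth_arg=True)  # True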
211    """
212
213    # need to insert placeholder observer for dynamic quantization so that it can
214    # be converted to choose_qparams -> q -> dq in convert step
215    if cur_target_is_dynamic:
        assert (
            cur_target_dtype in _OBS_DTYPE_LIST
        ), f"Expected cur_target_dtype to be one of {_OBS_DTYPE_LIST}, but got: {cur_target_dtype}"
        assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
        return is_zeroth_arg
    if reuse_input_obs_or_fq:
        return False
    # non dynamic quantization
    if cur_target_dtype in _OBS_DTYPE_LIST:
        return (
            prev_output_dtype in _OBS_DTYPE_LIST + [torch.float]
            and cur_target_dtype != prev_output_dtype
        )

    # a lot of error checking is skipped here for now
    return False


def _is_activation_post_process_node(
    node: Node, named_modules: Dict[str, torch.nn.Module]
) -> bool:
    return (
        isinstance(node, torch.fx.Node)
        and node.op == "call_module"
        and _is_activation_post_process(named_modules[str(node.target)])
    )


def _get_dtype_and_is_dynamic(
    obs_or_fq: Optional[ObserverOrFakeQuantize],
) -> Tuple[Optional[torch.dtype], bool]:
247    """Given a constructor for observer or fake quant module, returns
248    a Tuple of dtype and is_dynamic
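
    Example (illustrative)::

        obs = MinMaxObserver.with_args(dtype=torch.qint8)()
        _get_dtype_and_is_dynamic(obs)   # (torch.qint8, False)
        _get_dtype_and_is_dynamic(None)  # (None, False)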
249    """
250    # TODO: instead of instantiating the instance, we can use inspect to get the default args
251    if obs_or_fq is None:
252        return None, False
253    else:
254        return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False)  # type: ignore[return-value]
255
256
257def _is_input_arg_dtype_supported_by_backend(
258    arg: Argument,
259    node: Node,
260    qconfig: QConfigAny,
261    dtype_config: DTypeConfig,
262    backend_config: BackendConfig,
263) -> bool:
264    """Check if the configured qconfig for the argument
265    is supported by the backend or not
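
    Example (illustrative): given ``dtype_config = DTypeConfig(input_dtype=torch.quint8,
    output_dtype=torch.quint8, weight_dtype=torch.qint8, bias_dtype=torch.float)``,
    an activation argument passes this check when the qconfig's activation
    observer also has dtype ``torch.quint8`` (and satisfies any additional
    constraints), and fails when it has e.g. ``torch.float16``.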
266    """
267    if isinstance(arg, (list, tuple)):
268        return all(
269            _is_input_arg_dtype_supported_by_backend(
270                a, node, qconfig, dtype_config, backend_config
271            )
272            for a in arg
273        )
274    if not isinstance(arg, Node):
275        return True
276    # TODO: support check for standalone module
277    is_weight = node_arg_is_weight(node, arg)
278    is_bias = node_arg_is_bias(node, arg)
279    is_activation = not is_weight and not is_bias
280    if is_activation:
281        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
282            "input_act_obs_or_fq_ctr"
283        )
284        input_act_obs_or_fq = (
285            input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
286        )
287        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(
288            input_act_obs_or_fq
289        )
290        # TODO(future PR): remove the cast to bool below after figuring
291        # out why backend_config has is_dynamic set to None in some cases.
292        return (dtype_config.input_dtype is None) or (
293            dtype_config.input_dtype == qconfig_dtype
294            and bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic)
295            and _qconfig_satisfies_dtype_config_constraints(
296                qconfig, dtype_config.input_dtype_with_constraints
297            )
298        )
299    elif is_weight:
300        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
301        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
302            "weight_obs_or_fq_ctr", None
303        )
304        weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
305        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
306        backend_config_weight_dtype = dtype_config.weight_dtype
307        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
308        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
309            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False
310        )
311        return backend_config_weight_dtype is None or (
312            dtype_matches and qconfig_satisfies_constraints
313        )
314    else:  # bias
315        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
316        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
317            "bias_obs_or_fq_ctr", None
318        )
319        bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
320        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
321        backend_config_bias_dtype = dtype_config.bias_dtype
322        return (
323            backend_config_bias_dtype is None
324            or qconfig_bias_dtype == backend_config_bias_dtype
325        )
326
327
328def _is_output_dtype_supported_by_backend(
329    node: Node,
330    qconfig: QConfigAny,
331    dtype_config: DTypeConfig,
332) -> bool:
333    """Check if the configured qconfig for the output
334    is supported by the backend or not
335    """
336    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
337    backend_config_output_dtype = dtype_config.output_dtype
338    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
339    # from input activation check can be reused here
340    qconfig_output_dtype = None
341    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
342        "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
343    )
344    output_act_obs_or_fq = (
345        output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
346    )
347    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(
348        output_act_obs_or_fq
349    )
350    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
351    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
352    # linear op which has fp32 output dtype, this should be removed if we generalize
353    # the structure of qconfig in the future
354    if qconfig_output_is_dynamic:
355        qconfig_output_dtype = torch.float32
356    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
357    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
358        qconfig, dtype_config.output_dtype_with_constraints
359    )
360    return backend_config_output_dtype is None or (
361        dtype_matches and qconfig_satisfies_constraints
362    )
363
364
365def _is_observer_in_same_graph(
366    node: Node,
367    named_modules: Dict[str, torch.nn.Module],
368    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
369    is_qat,
370):
371    """Check if observer in same graph
372    when the node output is not fp32 and input is 'placeholder'
373    the input is assumed to be quantized, so it is observed
374    in a different place rather than not observed.
375    """
    node_output_dtype = _get_arg_target_dtype_as_output(
        node, named_modules, obs_or_fq_map, is_qat
    )
    if len(node.args) > 0 and isinstance(node.args[0], Node):
        if (
            node_output_dtype in [torch.quint8, torch.uint8]
            and node.args[0].op == "placeholder"
        ):
            return False
    return True


def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
    pattern: Optional[Pattern],
    matched_node_pattern: Optional[List[Node]],
    qconfig: QConfigAny,
    backend_config: BackendConfig,
) -> bool:
    """Check if the dtype configuration of a pattern is supported by
    the backend or not, and whether the qconfig satisfies constraints
    specified in the corresponding dtype config.
    """
    if backend_config is None or pattern is None:
        return True
    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)

    root_node_getter = pattern_to_root_node_getter.get(
        pattern, _default_root_node_getter
    )
    root_node = root_node_getter(matched_node_pattern)
    input_node = root_node
    output_node = matched_node_pattern[0]
    for dtype_config in dtype_configs:
        # check if arg dtypes are supported
        supported = True
        for arg in list(input_node.args) + list(input_node.kwargs.values()):
            supported = supported and _is_input_arg_dtype_supported_by_backend(
                arg, input_node, qconfig, dtype_config, backend_config
            )
        # check if output dtype is supported
        supported = supported and _is_output_dtype_supported_by_backend(
            output_node, qconfig, dtype_config
        )
        if supported:
            return True
    return False


def _get_standalone_module_configs(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    prepare_custom_config: PrepareCustomConfig,
    parent_qconfig: QConfigAny,
    parent_backend_config: Optional[BackendConfig],
) -> Tuple[
    QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]
]:
    """
    Returns the standalone module's QConfigMapping, example inputs,
    PrepareCustomConfig, and BackendConfig for `node`, assuming that
    the module pointed to by `node` is a standalone module.
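
    Example (an illustrative sketch; ``MySubmodule`` and the configs are
    hypothetical)::

        prepare_custom_config = PrepareCustomConfig().set_standalone_module_class(
            MySubmodule, qconfig_mapping, (example_input,), None, None
        )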
440    """
441    module_name = str(node.target)
442    module_type = type(named_modules[module_name])  # type: ignore[index]
443    # name config has precedence over type config
444    config_entry = StandaloneModuleConfigEntry(None, (), None, None)
445    config_entry = prepare_custom_config.standalone_module_classes.get(
446        module_type, config_entry
447    )
448    config_entry = prepare_custom_config.standalone_module_names.get(
449        module_name, config_entry
450    )
451    # fallback to use parent module's qconfig if user didn't specify qconfig dict
452    qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(
453        parent_qconfig
454    )
455    example_inputs = config_entry.example_inputs
456    prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
457    backend_config = config_entry.backend_config or parent_backend_config
458    return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
459
460
461def _qat_swap_modules(
462    root: torch.nn.Module, module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]
463) -> None:
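    # Example (illustrative): a typical `module_to_qat_module` mapping swaps
    # float modules for their QAT counterparts, e.g.
    #   {torch.nn.Conv2d: torch.ao.nn.qat.Conv2d,
    #    torch.nn.Linear: torch.ao.nn.qat.Linear}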
    convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)


def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
    if isinstance(matched_node_pattern, Node):
        s.add(matched_node_pattern.name)
    elif isinstance(matched_node_pattern, (list, tuple)):
        for maybe_node in matched_node_pattern:
            _add_matched_node_name_to_set(maybe_node, s)


def _insert_obs_or_fq(
    node: Node,
    obs_or_fq: ObserverOrFakeQuantize,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `obs_or_fq` to `model`, and creates a node which calls
    `obs_or_fq` on the output of `node`.

    obs_or_fq: an instance of Observer or FakeQuantize module
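
    Example (illustrative sketch)::

        obs_node = _insert_obs_or_fq(x_node, obs, model, named_modules, model.graph)
        # the graph now contains: x -> activation_post_process_0
        # callers reroute uses of `x` to `obs_node` as needed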
487    """
488    model_device = assert_and_get_unique_device(model)
489    if model_device:
490        obs_or_fq.to(model_device)
491    # add obs_or_fq module as attribute
492    if is_equalization_observer(obs_or_fq):
493        prefix = node.name + "_equalization_process_"
494    else:
495        prefix = "activation_post_process_"
496    get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
497    obs_or_fq_name = get_new_obs_or_fq_name(model)
498    setattr(model, obs_or_fq_name, obs_or_fq)
499    named_modules[obs_or_fq_name] = obs_or_fq
500    with graph.inserting_after(node):
501        new_obs = graph.create_node("call_module", obs_or_fq_name, (node,), {})
502    return new_obs
503
504
505def _set_target_dtype_info_for_matched_node_pattern(
506    matched_node_pattern: NodePattern,
507    last_node: Node,
508    qconfig: QConfigAny,
509    qhandler: Optional[QuantizeHandler],
510    backend_config: BackendConfig,
511    named_modules: Dict[str, torch.nn.Module],
512    cache_for_no_tensor_check: Dict[Node, bool],
513    processed_nodes: Set[Node],
514) -> None:
515    """Sets the target_dtype_info for each node in matched_node_pattern
516    Note: processed_nodes is used to ensure we only process each node once
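
    Example (illustrative): after this pass, a matched node carries metadata of
    roughly the following shape (see `_get_target_activation_dtype_for_node`)::

        node.meta["target_dtype_info"] = {
            "input_act_obs_or_fq_ctr": qconfig.activation,
            "weight_obs_or_fq_ctr": qconfig.weight,
            "output_act_obs_or_fq_ctr": qconfig.activation,
            # ... plus bias/weight indexes and observer-sharing flags
        }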
517    """
518    if isinstance(matched_node_pattern, (list, tuple)):
519        for node_pattern in matched_node_pattern:
520            _set_target_dtype_info_for_matched_node_pattern(
521                node_pattern,
522                last_node,
523                qconfig,
524                qhandler,
525                backend_config,
526                named_modules,
527                cache_for_no_tensor_check,
528                processed_nodes,
529            )
530
531    # set target_dtype_info if matched_node_pattern is a Node
532    # other types of matched object, e.g. int, float literals, are ignored
533    elif isinstance(matched_node_pattern, Node):
534        # for pyre
535        assert isinstance(matched_node_pattern, Node)
536        node = matched_node_pattern
537        if node in processed_nodes:
538            return
539        processed_nodes.add(node)
540
541        if qconfig is None:
542            return
543        # TODO: refactor the following code in terms of apply a qconfig to a pattern
544        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
545        # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act,
546        # and set output_obs_or_fq_ctr based on qconfig.output_act
547        # this also requires we extend the structure of QConfig to support more fine
548        # grained configurations
        target_dtype_info: Dict[str, Any] = _get_target_activation_dtype_for_node(
            node,
            qconfig,
            qhandler,
            named_modules,
            backend_config,
            cache_for_no_tensor_check,
        )
        node.meta["target_dtype_info"] = target_dtype_info


def _get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    named_modules: Dict[str, torch.nn.Module],
    backend_config: BackendConfig,
    cache_for_no_tensor_check: Dict[Node, bool],
) -> Dict[str, Any]:
    """
    For each of the op's attributes (input activation, output activation,
    weight, bias), returns the dtype and is_dynamic settings we expect
    for the `quantize` call in the reference model representation, or None
    if no `quantize` call is needed.

    For example, if we have a node corresponding to `op0` in

      x0 -> op0 -> x1

    And we want a reference quantized representation to be

      x0 -> quant_static -> dequant -> op0 -> quant_static -> dequant -> x1

    Then this function will return

      {
        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
      }

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    args_have_no_tensors = all_node_args_have_no_tensors(
        node, named_modules, cache_for_no_tensor_check
    )
    if args_have_no_tensors:
        return {
            "input_act_obs_or_fq_ctr": None,
            "output_act_obs_or_fq_ctr": None,
        }
    # get qconfig to determine the eventual dtype of this node
    if qconfig is not None:
        act_dtype, weight_dtype, input_act_is_dynamic = get_qconfig_dtypes(qconfig)

        # Currently `QConfig` only has one `activation` field.
        # For static quantization, it is reused for both input
        # and output activation. For dynamic quantization, this
        # field is currently only used for the input activation,
        # with the output activation being in fp32.
        # In the future this may change as we add more fields
        # to the `QConfig` object.
        output_act_dtype = act_dtype if (not input_act_is_dynamic) else torch.float

        bias_dtype = (
            torch.float16
            if (
                act_dtype == torch.float16
                and weight_dtype == torch.float16
                and (not input_act_is_dynamic)
            )
            else torch.float
        )

        is_general_tensor_value_op = (
            qhandler is not None and qhandler.is_general_tensor_value_op()
        )

        _is_standalone_module = qhandler is not None and qhandler.is_standalone_module()

        weight_index = None
        if (
            isinstance(node, Node)
            and node.op == "call_function"
            and node.target in backend_config._pattern_complex_format_to_config
        ):
            weight_index = backend_config._pattern_complex_format_to_config[
                node.target
            ]._input_type_to_index.get("weight")

        bias_index = None
        if (
            isinstance(node, Node)
            and node.op == "call_function"
            and node.target in backend_config._pattern_complex_format_to_config
        ):
            bias_index = backend_config._pattern_complex_format_to_config[
                node.target
            ]._input_type_to_index.get("bias")

        return {
            "input_act_obs_or_fq_ctr": qconfig.activation,
            "weight_obs_or_fq_ctr": qconfig.weight,
            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
            "weight_index": weight_index,
            "bias_index": bias_index,
            "output_act_obs_or_fq_ctr": qconfig.activation,
            "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
            "input_output_share_observers": is_general_tensor_value_op,
            "_is_standalone_module": _is_standalone_module,
        }
    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)


def _get_output_act_obs_or_fq(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> ObserverOrFakeQuantize:
    """Get the observer or fake quant instance for the argument `arg`, treated
    as the output of the previous node in the original graph, skipping any
    inserted observers.

    We assume that the observers are inserted correctly, and that the dtype of the
    argument in the quantized graph will match what is specified by the qconfig
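
    Example (illustrative): for a graph ``conv -> obs0 -> relu``, querying this
    for the ``obs0`` argument of ``relu`` traces back through the observer to
    ``conv`` and builds the instance from ``conv``'s recorded output
    configuration.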
675    """
676    assert isinstance(arg, Node)
677    if "quantization_annotation" in arg.meta:
678        return _create_obs_or_fq_from_qspec(
679            arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
680        )
681
682    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
683    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
684    # Since we modified the graph in this case, we must trace back from the args through
685    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
686    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
687    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(
688        arg, named_modules
689    )
690    output_act_obs_or_fq_ctr = None
691    if custom_module_lstm_node is not None:
692        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"][
693            "output_act_obs_or_fq_ctr"
694        ]
695        output_act_obs_or_fq = (
696            output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
697        )
698    elif _is_activation_post_process_node(arg, named_modules):
699        observed_arg = arg.args[0]
700        assert isinstance(
701            observed_arg, Node
702        ), "Currently we only support observing Node"
703        if "quantization_annotation" in observed_arg.meta:
704            output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
705                observed_arg.meta["quantization_annotation"].output_qspec,
706                obs_or_fq_map,
707                is_qat,
708            )
709        else:
710            assert "target_dtype_info" in observed_arg.meta
711            output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][
712                "output_act_obs_or_fq_ctr"
713            ]
714            output_act_obs_or_fq = (
715                output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
716            )
717    else:
718        if "target_dtype_info" in arg.meta:
719            output_act_obs_or_fq_ctr = arg.meta["target_dtype_info"].get(
720                "output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
721            )
722        else:
723            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
724        output_act_obs_or_fq = (
725            output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
726        )
727
728    return output_act_obs_or_fq
729
730
731def _get_arg_target_dtype_as_output(
732    arg: Node,
733    named_modules: Dict[str, torch.nn.Module],
734    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
735    is_qat: bool,
736) -> Optional[torch.dtype]:
737    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
738        arg, named_modules, obs_or_fq_map, is_qat
739    )
740    arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(
741        arg_as_output_act_obs_or_fq
742    )
743    return arg_as_output_target_dtype
744
745
746def _get_arg_as_input_act_obs_or_fq(
747    arg: Node,
748    node: Node,
749    named_modules: Dict[str, torch.nn.Module],
750    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
751    is_qat: bool,
752) -> Optional[ObserverOrFakeQuantize]:
753    """Get the observer or fake quant constructor for the Argument `arg`, as input
754    to Node `node`
755    """
756    assert isinstance(arg, Node)
757    # "input_qspec_map" is the more general design we'll use for pt2e path
758    # it is a map from input argument node to observer or fake quant constructor, for example
759    # for the following graph:
760    # x -> conv -> output
761    #
762    # we may annotate conv node like the following:
763    # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
764    #
765    if "quantization_annotation" in node.meta:
766        input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
767        input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
768        if input_arg_qspec is None:
769            input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR()
770        else:
771            input_arg_obs_or_fq = _create_obs_or_fq_from_qspec(
772                input_arg_qspec, obs_or_fq_map, is_qat
773            )
774        return input_arg_obs_or_fq
775
776    # we can remove the following path in the future if fx graph mode quantization is
777    # no longer used
778    is_weight = node_arg_is_weight(node, arg)
779    is_bias = node_arg_is_bias(node, arg)
780    is_activation = not is_weight and not is_bias
781    obs_or_fq_ctr = None
782    if is_activation:
783        obs_or_fq_ctr = node.meta["target_dtype_info"].get(
784            "input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
785        )
786    elif is_weight:
787        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
788            obs_or_fq_ctr = node.meta["target_dtype_info"].get(
789                "weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
790            )
791    else:
792        obs_or_fq_ctr = node.meta["target_dtype_info"].get(
793            "bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR
794        )
795    return obs_or_fq_ctr() if obs_or_fq_ctr else None
796
797
798def _maybe_insert_input_observer_for_arg_or_kwarg(
799    node: Union[Node, Any],
800    arg: Argument,
801    qconfig: QConfigAny,
802    model: torch.nn.Module,
803    named_modules: Dict[str, torch.nn.Module],
804    graph: Graph,
805    qhandler: Optional[QuantizeHandler],
806    prepare_custom_config: PrepareCustomConfig,
807    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
808    is_qat: bool,
809    backend_config: Optional[BackendConfig] = None,
810) -> Argument:
811    """
812    Given a `node` and an `arg`, inserts an input observer between
813    `node` and `arg` if necessary.
814    """
815    # for ops such as torch.cat([x0, x1]),
816    # traverse through the list
817    if isinstance(arg, (list, tuple)):
818        new_arg_to_return = []
819        for inner_arg in arg:
820            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
821                node,
822                inner_arg,
823                qconfig,
824                model,
825                named_modules,
826                graph,
827                qhandler,
828                prepare_custom_config,
829                obs_or_fq_map,
830                is_qat,
831                backend_config,
832            )
833            new_arg_to_return.append(new_inner_arg)
834        return type(arg)(new_arg_to_return)
835
836    if not isinstance(arg, Node):
837        return arg
838    assert isinstance(arg, Node)
839    # default (no observer)
840    new_arg = arg
841
842    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
843    # TODO: move this to a separate function
844    if not is_standalone_module:
        # Note: qconfig can be None in this branch, since we are getting act/fq from
        # node.meta now
        # regular flow for most nodes, except standalone modules

        if "quantization_annotation" in node.meta:
            reuse_input_obs_or_fq = node.meta[
                "quantization_annotation"
            ]._reuse_input_obs_or_fq
        else:
            assert "target_dtype_info" in node.meta
            # TODO: we are assuming "target_dtype_info" exists here, maybe
            # a default value also needs to be provided here
            target_dtype_info = node.meta["target_dtype_info"]
            # for nodes that don't have `reuse_input_obs_or_fq` configured,
            # we'll default to False; this makes configuring this field optional for users
            reuse_input_obs_or_fq = target_dtype_info.get(
                "reuse_input_obs_or_fq", False
            )
        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(
            arg, node, named_modules, obs_or_fq_map, is_qat
        )
        (
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
        ) = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(
            arg, named_modules, obs_or_fq_map, is_qat
        )
        (
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
        ) = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)

        needs_obs_or_fq = _needs_obs_or_fq(
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
            reuse_input_obs_or_fq,
            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
        )

    else:
        assert qconfig is not None
        # custom flow for standalone modules
        _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
            node, named_modules, prepare_custom_config, qconfig, backend_config
        )
        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes

        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            needs_obs_or_fq = False
        else:
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
                arg, named_modules, obs_or_fq_map, is_qat
            )
            arg_as_input_target_dtype = (
                torch.quint8
                if cur_input_idx in sm_input_quantized_idxs
                else torch.float
            )
            needs_obs_or_fq = (
                arg_as_output_target_dtype != arg_as_input_target_dtype
            ) and (arg_as_input_target_dtype != torch.float)

        act_post_process_ctr = qconfig.activation
        arg_as_input_act_obs_or_fq = (
            act_post_process_ctr() if act_post_process_ctr else None
        )

    if needs_obs_or_fq:
        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        # TODO: this looks at how the value is used in the future, which
        # we should remove; removing it means we insert one observer for each
        # use, even if the dtypes are the same, and we could then add an extra
        # pass that removes the redundant observers
        for maybe_obs_node in arg.users.keys():
            if maybe_obs_node.op == "call_module":
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq)
                    and maybe_obs_mod.dtype
                    == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
                ):
                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                    existing_obs_node = maybe_obs_node
                    break

        assert arg_as_input_act_obs_or_fq is not None
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
        if existing_obs_node is None:
            new_obs_node = _insert_obs_or_fq(
                arg, arg_as_input_act_obs_or_fq, model, named_modules, graph
            )
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg


def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None,
) -> None:
    """
    If needed, inserts observers for the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node

    Note: backend_config is only needed for standalone_module nodes
986    """
987    # Look through every input arg.  If that arg's target dtype does not
988    # match the current node's target dtype, insert an observer.
989    new_args = []
990    for arg in node.args:
991        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
992            node,
993            arg,
994            qconfig,
995            model,
996            named_modules,
997            graph,
998            qhandler,
999            prepare_custom_config,
1000            obs_or_fq_map,
1001            is_qat,
1002            backend_config,
1003        )
1004        new_args.append(new_arg)
1005
1006    new_kwargs = {}
1007    for k, kwarg in node.kwargs.items():
1008        new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg(
1009            node,
1010            kwarg,
1011            qconfig,
1012            model,
1013            named_modules,
1014            graph,
1015            qhandler,
1016            prepare_custom_config,
1017            obs_or_fq_map,
1018            is_qat,
1019            backend_config,
1020        )
1021        new_kwargs[k] = new_kwarg
1022
1023    # assign the new args and kwargs to the node, inplace
1024    node.args = tuple(new_args)
1025    node.kwargs = new_kwargs
1026
1027
1028def _maybe_insert_input_equalization_observers_for_node(
1029    node: Node,
1030    equalization_qconfig: Any,
1031    model: torch.nn.Module,
1032    named_modules: Dict[str, torch.nn.Module],
1033    graph: Graph,
1034    is_branch: bool,
1035) -> None:
1036    """
1037    If `node` needs to be equalized, find the input/weight observers it needs in
1038    `equalization_qconfig`, creates them, and inserts it into `graph`.
1039
1040    If `node` does not need an equalization observer, returns None.
1041    """
1042    if equalization_qconfig is None or not node_supports_equalization(
1043        node, named_modules
1044    ):
1045        return
1046
1047    if is_branch:
1048        warnings.warn(f"Cannot equalize {node} because it is part of a branch.")
1049        return
1050
1051    new_args = []
1052    for arg in node.args:
1053        if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
1054            new_args.append(arg)
1055            continue
1056
1057        is_weight = node_arg_is_weight(node, arg)
1058
1059        act_eq_process_ctr = (
1060            equalization_qconfig.weight
1061            if is_weight
1062            else equalization_qconfig.input_activation
1063        )
1064
1065        new_eq_obs_mod = act_eq_process_ctr()
1066        new_eq_obs_node = _insert_obs_or_fq(
1067            arg, new_eq_obs_mod, model, named_modules, graph
1068        )
1069
1070        new_args.append(new_eq_obs_node)
1071
1072    # assign the new args and kwargs to the node, inplace
1073    node.args = tuple(new_args)
1074
1075
1076def _maybe_insert_output_observer_for_node(
1077    node: Node,
1078    model: torch.nn.Module,
1079    named_modules: Dict[str, torch.nn.Module],
1080    graph: Graph,
1081    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
1082    is_qat: bool,
1083) -> Optional[Node]:
1084    """
1085    If `node` needs an output observer, creates it, inserts it into `graph`
1086    and returns it.
1087
1088    If `node` does not need an output observer, returns None.
1089
1090    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
1091    quantization code path right now
1092    """
    assert node.op != "output", "observer insertion for outputs is handled elsewhere"

    is_standalone_module = False
    if "quantization_annotation" in node.meta:
        output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
            node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
        )
    else:
        assert "target_dtype_info" in node.meta
        is_standalone_module = node.meta["target_dtype_info"].get(
            "_is_standalone_module", False
        )
        output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get(
            "output_act_obs_or_fq_ctr"
        )
        output_act_obs_or_fq = (
            output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
        )
    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # uncomment after we support reuse_input_obs_or_fq properly by having separate
    # implementations for this key instead of reusing the input_output_share_observers
    # code
    # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
    # for now we set this to False since reuse_input_obs_or_fq for
    # the output of a node is implemented in the same code path as observer sharing;
    # we should refactor this part to make it clearer in the future,
    # and then we would be able to read this from the config directly
    reuse_input_obs_or_fq = False

    # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
    # because the prev_output is the output of an fp32 op, although technically
    # we should get the dtype of the output from node.meta["val"] in the future
    # if we deprecate fx graph mode quantization
    needs_obs_or_fq = _needs_obs_or_fq(
        torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq
    )
    # currently the activation in QConfig(activation=...,) is for both input
    # and output, and when the activation is configured to be dynamic quantization
    # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
    # the input should be dynamically quantized, but the output should not be quantized
    #
    # there is no way we can specify different observer/fq for input and output
    # activation through QConfig today; this limitation is lifted in the
    # quantizer/annotation API in the pytorch 2.0 export quantization code path,
    # but since this code is reused, annotating output to be dynamically quantized
    # would not work either for that.
    # we can change QConfig to support input/output activation if we want
    # to remove the following check, or if we can deprecate fx graph mode quantization
    if target_is_dynamic:
        needs_obs_or_fq = False

    # we never insert observers to the output of a standalone module; we assume
    # that if needed, they are inserted inside the standalone module
    needs_obs_or_fq = needs_obs_or_fq and (not is_standalone_module)

    if needs_obs_or_fq:
        obs_or_fq_map[node] = output_act_obs_or_fq
        return _insert_obs_or_fq(
            node, output_act_obs_or_fq, model, named_modules, graph
        )
    else:
        return None


def _maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.
    """

    def _recursive_maybe_replace_node_with_obs(
        maybe_node: Argument,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Navigate an arbitrary data structure of lists, tuples, dicts.
        For each container type, recurse on all inputs. Once any Node
        is found, insert an observer if needed and do not recurse further.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if we inserted an observer then replace the original node
        with its observer.

        Returns the data structure with all nodes needing observation being
        replaced by their observers.
        """
        if isinstance(maybe_node, Node):
            # check dtype of this node
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
                maybe_node, named_modules, obs_or_fq_map, is_qat
            )
            observer_mod = None
            arg_as_input_target_dtype = torch.float
            if "target_dtype_info" in maybe_node.meta:
                observer_cls = maybe_node.meta["target_dtype_info"].get(
                    "input_act_obs_or_fq_ctr", None
                )
                if observer_cls is not None:
                    observer_mod = observer_cls()
                    arg_as_input_target_dtype = observer_mod.dtype
            # TODO: this does not handle dynamic quantization yet
            need_obs = (
                arg_as_output_target_dtype != arg_as_input_target_dtype
                and arg_as_input_target_dtype != torch.float
            )
            if need_obs:
                assert observer_mod is not None
                # insert observer
                observer_node = _insert_obs_or_fq(
                    maybe_node, observer_mod, model, named_modules, graph
                )
                return observer_node
            else:
                return maybe_node
        elif isinstance(maybe_node, (list, tuple)):
            results = []
            for inner_node in maybe_node:
                results.append(
                    _recursive_maybe_replace_node_with_obs(
                        inner_node, model, named_modules, graph
                    )
                )
            if isinstance(maybe_node, list):
                return results
            else:
                return tuple(results)
        elif isinstance(maybe_node, dict):
            results_dict = {}
            for k, inner_v in maybe_node.items():
                results_dict[k] = _recursive_maybe_replace_node_with_obs(
                    inner_v, model, named_modules, graph
                )
            return results_dict
        elif maybe_node is None:
            return None
        else:
            raise Exception(  # noqa: TRY002
                "Unhandled type for returned node:", maybe_node
            )

    new_args = []
    for old_arg in graph_output_node.args:
        new_args.append(
            _recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
        )

    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]


def _maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: Union[torch.dtype, type],
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Marks `node` as not needing observation by clearing its input and output
    observer/fake_quant constructors, effectively assigning it `target_dtype`.
    If `node` is a general tensor value op, also calls this function
    recursively on the first argument, to propagate the dtype up to the
    producer node.
    """
    node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
    node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
    # if this is a copy node, propagate to first arg
    (
        root_node,
        _,
        pattern,
        qhandler,
        qconfig,
    ) = node_name_to_match_result_with_qconfig.get(
        node.name, (None, None, None, None, None)
    )
    # TODO: probably need to remove `is_general_tensor_value_op`
    if qhandler is not None and qhandler.is_general_tensor_value_op():
        prev_node = node.args[0]
        if isinstance(prev_node, Node):
            _maybe_propagate_dtype_for_node(
                prev_node, target_dtype, node_name_to_match_result_with_qconfig
            )


def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Currently we assume that inputs to the graph are either `torch.float` or
    `torch.quint8`, which is not always correct. For ops such as
    `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
    `BoolTensor`. Propagate this information throughout the graph.

    Note: not all dtypes in the graph will be correct after this pass, but a
    higher percentage of them will be correct. Hopefully in the future we can
    replace this with a better way to reason about dtypes of tensors.
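
    Example (illustrative): for a call like ``x.masked_fill(mask, value)``,
    the ``mask`` argument is known to be a ``torch.bool`` tensor, so its
    observer constructors are cleared and no observer is inserted for it.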
1300    """
1301    for node in graph.nodes:
1302        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)
1303
1304        for arg_type in non_observable_arg_dict:
1305            non_observable_indices = non_observable_arg_dict[arg_type](node)
1306
1307            for index in non_observable_indices:
1308                arg = node.args[index]
1309
1310                # when an argument is a tuple, it does not show up as another node so we need to go through
1311                # all elements of the tuple manually
1312                if isinstance(arg, (tuple, list)):
1313                    arg_list = list(arg)
1314                else:
1315                    arg_list = [arg]
1316
1317                for cur_arg in arg_list:
1318                    # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
1319                    if isinstance(cur_arg, torch.fx.node.Node):
1320                        _maybe_propagate_dtype_for_node(
1321                            cur_arg, arg_type, node_name_to_match_result_with_qconfig
1322                        )
1323
1324
1325def _maybe_make_input_output_share_observers(
1326    node: Node,
1327    model: torch.nn.Module,
1328    named_modules: Dict[str, torch.nn.Module],
1329) -> bool:
1330    """
1331    Ensures that we share an observer
1332    for all input arguments as well as the output argument. In detail, given
1333    a graph of
1334
1335      x0 -> obs0 -> op -> x2
1336                  /
1337      x1 -> obs1 /
1338
1339    where node obs0 points to observer instance observer0,
1340    obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
1341    and ob2 point to observer0.
1342    Returns: whether the operation succeeded or not
1343    """
1344    first_arg = None
1345    # find the first non-Tensor arg
1346    for i in range(len(node.args)):
1347        if isinstance(node.args[i], (Node, list, tuple)):
1348            first_arg = node.args[i]
1349            break
1350
1351    # if there is no non-Tensor arg, return directly
1352    if first_arg is None:
1353        return False
1354
1355    if isinstance(first_arg, (list, tuple)):
1356        first_arg_arg = first_arg[0]
1357    elif isinstance(first_arg, Node):
1358        first_arg_arg = first_arg
1359    else:
1360        return False
1361
1362    # if we have a graph such as
1363    #   observed_node -> non_observed_node -> cat
1364    # we need to navigate up to the first observer
1365    iteration_guard = 0
1366    while not _is_activation_post_process_node(first_arg_arg, named_modules):
1367        if not isinstance(first_arg_arg, Node):
1368            return False
1369        # did not find an activation_post_process for the op
1370        if first_arg_arg.op == "placeholder":
1371            return False
1372        # trace back the args until we found the first Tensor/Node
1373        trace_back_node = None
1374        for i in range(len(first_arg_arg.args)):
1375            trace_back_node = first_arg_arg.args[i]
1376            if isinstance(trace_back_node, Node):
1377                break
1378        if trace_back_node is None:
1379            return False
1380        first_arg_arg = trace_back_node
1381
1382        iteration_guard += 1
1383        if iteration_guard > 10000:
1384            raise AssertionError("Unable to find observer of previous node")
1385
1386    assert isinstance(first_arg_arg, Node)
1387    target_to_use = first_arg_arg.target
1388    assert isinstance(target_to_use, str)
1389    obs_mod_to_use = named_modules[target_to_use]
1390
1391    if isinstance(first_arg, (list, tuple)):
1392        # set all other input observer nodes to use that module
1393        for input_idx, input_arg in enumerate(first_arg):
1394            if input_idx == 0:
1395                continue
1396            iteration_guard = 0
1397            while not _is_activation_post_process_node(input_arg, named_modules):
1398                # failed to trace back since no input arg for the current node
1399                if len(input_arg.args) < 1:
1400                    return False
1401                input_arg = input_arg.args[0]
1402                iteration_guard += 1
1403                if iteration_guard > 10000:
1404                    raise AssertionError("Unable to find observer of previous node")
1405
1406            parent_name, name = _parent_name(input_arg.target)
1407            setattr(named_modules[parent_name], name, obs_mod_to_use)
1408
1409    # set the output observer node to use that module
1410    for output_obs_node in node.users.keys():
1411        assert _is_activation_post_process_node(output_obs_node, named_modules)
1412        parent_name, name = _parent_name(output_obs_node.target)
1413        setattr(named_modules[parent_name], name, obs_mod_to_use)
1414
1415    # TODO(future PR): delete the orphaned observer modules
1416    return True


def _remove_output_observer(
    node: Node, model: torch.nn.Module, named_modules: Dict[str, torch.nn.Module]
):
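    """
    Removes the output observer(s) of `node`: every user of an output observer
    is rerouted back to `node` itself, then the observer node is erased from
    the graph.
    """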
    items = list(node.users.items())
    for output_obs_node, _ in items:
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        output_obs_node.replace_all_uses_with(node)
        model.graph.erase_node(output_obs_node)  # type: ignore[union-attr, operator]


def _swap_custom_module_to_observed(
    node: Node,
    qconfig: QConfigAny,
    named_modules: Dict[str, torch.nn.Module],
    prepare_custom_config: PrepareCustomConfig,
):
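    """
    Swaps the float custom module referenced by `node` for its observed
    counterpart, looked up in the float-to-observed mapping of
    `prepare_custom_config` and constructed via the observed class's
    `from_float`.
    """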
    custom_module = named_modules[node.target]  # type: ignore[index]
    custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping
    observed_custom_module_class = get_swapped_custom_module_class(
        custom_module, custom_module_class_mapping, qconfig
    )
    observed_custom_module = observed_custom_module_class.from_float(custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(named_modules[parent_name], name, observed_custom_module)


def insert_observers_for_model(
    model: GraphModule,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
    node_name_to_qconfig: Dict[str, QConfigAny],
    prepare_custom_config: PrepareCustomConfig,
    equalization_config_map: Dict[str, Any],
    backend_config: BackendConfig,
    observed_node_names: Set[str],
    is_qat: bool,
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and save
           it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
           dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to be
    complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.
    """

    # node.meta["target_dtype_info"] stores the target dtype information
    # derived from the qconfig for the Node. For example, if we have
    # a conv2d node that has a qconfig
    # qconfig = QConfig(activation=..., weight=...)
    # # information for input and bias node omitted
    # # for getattr node
    # # weight = getattr(self, 'weight')
    # weight.meta["target_dtype_info"] = {
    #    'output_act_obs_or_fq_ctr': qconfig.weight,
    # }
    # # for conv2d node
    # # conv2d = call_function[target=torch.nn.functional.conv2d](
    # #            args=(input, weight, bias))
    # conv2d.meta["target_dtype_info"] = {
    #   'input_act_obs_or_fq_ctr': qconfig.activation,
    #   'weight_obs_or_fq_ctr': qconfig.weight,
    #   'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
    #   'output_act_obs_or_fq_ctr': qconfig.activation,
    # }
    #
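    # (Illustrative note: with the default fbgemm/x86 qconfig, for instance,
    # activations would target quint8 and weights qint8; the exact dtypes are
    # determined by the observer constructors in the user's qconfig.)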
    cache_for_no_tensor_check: Dict[Node, bool] = {}

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overridden
    # other nodes' output dtype is specified by the qconfig
    named_modules = dict(model.named_modules(remove_duplicate=False))

    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
    processed_nodes: Set[Node] = set()
    # initialize target_dtype_info
    for node in model.graph.nodes:
        node.meta["target_dtype_info"] = copy.copy(
            _DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO
        )

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    placeholder_node_to_input_index: Dict[Node, int] = {}
    # TODO: we probably don't need this counter since each graph will only have
    # one output node?
    output_node_to_output_index: Dict[Node, int] = {}
    for node in model.graph.nodes:
        if node.op == "placeholder":
            placeholder_node_to_input_index[node] = inputs_seen_counter
            inputs_seen_counter += 1
        if node.op == "output":
            output_node_to_output_index[node] = outputs_seen_counter
            outputs_seen_counter += 1
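    # e.g. for a graph with placeholders `x` and `y` and a single output, this
    # yields (illustrative) placeholder_node_to_input_index == {x: 0, y: 1}
    # and output_node_to_output_index == {output: 0}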

    # Step 1, set the observer or fake quantize module constructor for each node in the
    # matched_node_pattern

    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
        (
            last_node,
            matched_node_pattern,
            pattern,
            qhandler,
            qconfig,
        ) = match_res_with_qconfig
        assert qhandler is not None
        _set_target_dtype_info_for_matched_node_pattern(
            matched_node_pattern,
            last_node,
            qconfig,
            qhandler,
            backend_config,
            named_modules,
            cache_for_no_tensor_check,
            processed_nodes,
        )

    # Step 2. Special cases for some operators; we might be able to remove them
    # in the future if we know the dtype information of each node better

    # Step 2.1. some settings are not based on patterns; we need to process each
    # node instead
    for node in model.graph.nodes:
        if (
            node.op == "placeholder"
            and placeholder_node_to_input_index[node] in input_quantized_idxs
        ):
            # users are not supposed to call calculate_qparams on a
            # PlaceholderObserver; this is OK because we only use it to encode
            # the dtype of the input tensor. We won't actually insert these
            # observers into the graph and won't call calculate_qparams
            node.meta["target_dtype_info"] = copy.copy(
                _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
            )
        elif node.op in ("call_module", "call_method", "call_function"):
            args_have_no_tensors = all_node_args_have_no_tensors(
                node, named_modules, cache_for_no_tensor_check
            )
            if args_have_no_tensors:
                node.meta["target_dtype_info"] = {
                    "input_act_obs_or_fq_ctr": None,
                    "output_act_obs_or_fq_ctr": None,
                }
        elif (
            node.op == "output"
            and output_node_to_output_index[node] in output_quantized_idxs
        ):
            # TODO(future PR): update the output_quantized_idxs API to match
            # arbitrary data structures. There is always a single output, and
            # that output can have arbitrary nesting of values. List[int] is
            # not the right data type for this.

            # TODO(future PR): support more dtypes in model outputs, if necessary
            node.meta["target_dtype_info"] = copy.copy(
                _DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO
            )

    # Step 2.2, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(
        model.graph, node_name_to_match_result_with_qconfig
    )

    # Step 3, check whether the requested target_dtype_info is supported by the
    # backend; if not, we'll reset the target_dtype_info to use the default
    # (float Tensor)

    # reset the counters and set of processed_nodes
    processed_nodes = set()
    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
        (
            last_node,
            matched_node_pattern,
            pattern,
            qhandler,
            qconfig,
        ) = match_res_with_qconfig
        is_supported_by_backend = (
            _is_pattern_dtype_config_and_qconfig_supported_by_backend(
                pattern, matched_node_pattern, qconfig, backend_config
            )
        )
        assert qhandler is not None

        # get output_act_dtype so that we don't also reset the special typed nodes
        # TODO: we might want to handle these more uniformly with the default path
        # this can be improved if we can use node.meta["val"]
        output_act_or_fq_ctr = last_node.meta["target_dtype_info"][
            "output_act_obs_or_fq_ctr"
        ]
        output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
        output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
        if not is_supported_by_backend and output_act_dtype not in [
            None,
            int,
            float,
            torch.bool,
        ]:
            # restore target_dtype_info to default if it is not supported by backend
            _set_target_dtype_info_for_matched_node_pattern(
                matched_node_pattern,
                last_node,
                torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
                None,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes,
            )

    # After this point, the current node and all of its arguments
    # have a target_dtype_info assigned. Now, we insert observers for the inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # Avoid duplicate custom module swaps for multiple nodes with the same target.
    custom_module_names_already_swapped: Set[str] = set()

    # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
    # reset inputs/outputs counters
    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}

    # TODO: change this to insert obs/fq by pattern instead of by node
    for node in nodes_before_observation:
        if node.op == "placeholder":
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            #   outside of the graph, and no additional observation is needed
            pass

        elif node.op in ("call_module", "call_method", "call_function", "output"):
            # check for matches
            (
                last_node,
                matched_node_pattern,
                pattern,
                qhandler,
                qconfig,
            ) = node_name_to_match_result_with_qconfig.get(  # type: ignore[assignment]
                node.name, (None, None, None, None, None)
            )
            equalization_qconfig = equalization_config_map.get(node.name, None)

            this_node_dtype_info = node.meta["target_dtype_info"]
            if "val" in node.meta:
                output_is_a_tensor = this_node_dtype_info is not None and isinstance(
                    node.meta["val"], FakeTensor
                )
            else:
                output_is_a_tensor = this_node_dtype_info is not None

            skip_inserting_observers = (
                (qconfig is None) or not output_is_a_tensor
            ) and (node.op != "output")

            # TODO: take a closer look to see if we can remove this check
            # right now it is here because of `observed_node_names`; we are using
            # it as an indicator for swapping the modules to reference modules in
            # convert
            is_supported_by_backend = (
                _is_pattern_dtype_config_and_qconfig_supported_by_backend(
                    pattern, matched_node_pattern, qconfig, backend_config
                )
            )

            if not skip_inserting_observers and is_supported_by_backend:
                named_modules = dict(model.named_modules(remove_duplicate=False))
                if node.op != "output":
                    assert matched_node_pattern is not None
                    # add matched nodes to the observed node name set
                    _add_matched_node_name_to_set(
                        matched_node_pattern, observed_node_names
                    )

                    # This is currently only used for equalization.
                    # Checks if the current node is in a branch in which the
                    # first two layers are both being quantized.
                    #
                    # ex.       conv2
                    #         /
                    #      x -> conv1
                    #
                    # If this is the case, we will not apply equalization to the
                    # initial two layers.
                    is_quantized_branch = False
                    if (
                        len(node.args) > 0
                        and isinstance(node.args[0], Node)
                        and len(node.args[0].users) > 1
                    ):
                        for user in node.args[0].users:
                            # Checks if there exists another user being quantized
                            is_user_quantized = node_name_to_qconfig.get(
                                user.name, None
                            ) is not None or (
                                user.op == "call_module"
                                and isinstance(
                                    named_modules[str(user.target)], ObserverBase
                                )
                            )
                            if user != node and is_user_quantized:
                                is_quantized_branch = True

                    pattern_to_root_node_getter = (
                        get_fusion_pattern_to_root_node_getter(backend_config)
                    )
                    root_node_getter = pattern_to_root_node_getter.get(
                        pattern, _default_root_node_getter
                    )
                    root_node = root_node_getter(matched_node_pattern)
                    is_input_node_of_the_pattern = node is root_node
                    if is_input_node_of_the_pattern:
                        # this modifies node inplace
                        _maybe_insert_input_observers_for_node(
                            node,
                            qconfig,
                            model,
                            named_modules,
                            model.graph,
                            qhandler,
                            prepare_custom_config,
                            obs_or_fq_map,
                            is_qat,
                            backend_config,
                        )

                        # insert equalization input observers if needed
                        _maybe_insert_input_equalization_observers_for_node(
                            node,
                            equalization_qconfig,
                            model,
                            named_modules,
                            model.graph,
                            is_quantized_branch,
                        )

                    is_last_node_of_pattern = node is last_node
                    input_output_share_observers = node.meta["target_dtype_info"].get(
                        "input_output_share_observers", False
                    )
                    reuse_input_obs_or_fq = node.meta["target_dtype_info"].get(
                        "reuse_input_obs_or_fq", False
                    )

                    if is_last_node_of_pattern:
                        if _is_custom_module_lstm(
                            node, named_modules, qconfig, qhandler
                        ):
                            # Currently custom module outputs are assumed to be already quantized,
                            # so we need to insert a DeQuantStub after the output. For custom module
                            # LSTM specifically, the outputs are also a nested tuple, so we must first
                            # break down the tuple to insert DeQuantStubs after the internal nodes.

                            # TODO: This currently diverges from how custom modules are handled today,
                            # where we insert observers after the output instead of DeQuantStubs, and
                            # replace these observers with "dequantize" nodes during convert. Conceptually,
                            # these output observers are the same as DeQuantStubs. In the future, we
                            # should resolve this inconsistency by inserting DeQuantStubs for all custom
                            # modules, not just for LSTM.
                            _insert_dequant_stubs_for_custom_module_lstm_output(
                                node, model, named_modules, model.graph
                            )
                            if node.target not in custom_module_names_already_swapped:
                                custom_module_names_already_swapped.add(node.target)
                                _swap_custom_module_to_observed(
                                    node, qconfig, named_modules, prepare_custom_config
                                )
                        else:
                            # this returns the new observer node if it was needed
                            maybe_output_obs_node = (
                                _maybe_insert_output_observer_for_node(
                                    node,
                                    model,
                                    named_modules,
                                    model.graph,
                                    obs_or_fq_map,
                                    is_qat,
                                )
                            )

                            if maybe_output_obs_node is not None:
                                # Update users of the original node to use the
                                # output observer instead. For example, change
                                #
                                #           next_node
                                #          /
                                #   cur_node -> obs
                                #
                                # to
                                #
                                #   cur_node -> obs -> next_node
                                #
                                # We need to save the original users before
                                # updating uses because the list of users will
                                # change as we update them
                                orig_users = list(node.users.keys())
                                for user_node in orig_users:
                                    if user_node is maybe_output_obs_node:
                                        continue
                                    user_node.replace_input_with(
                                        node, maybe_output_obs_node
                                    )

                                _is_observer_in_same_graph_ = (
                                    _is_observer_in_same_graph(
                                        node, named_modules, obs_or_fq_map, is_qat
                                    )
                                )

                                # for ops whose inputs and outputs share observers/fqs,
                                # we modify the graph to make all inputs and outputs
                                # use the first input's observer/fq
                                if (
                                    input_output_share_observers
                                    and _is_observer_in_same_graph_
                                ) or reuse_input_obs_or_fq:
                                    if not _maybe_make_input_output_share_observers(
                                        node, model, named_modules
                                    ):
                                        _remove_output_observer(
                                            node, model, named_modules
                                        )

                                if qhandler is not None and qhandler.is_custom_module():
                                    if (
                                        node.target
                                        not in custom_module_names_already_swapped
                                    ):
                                        custom_module_names_already_swapped.add(
                                            node.target
                                        )
                                        _swap_custom_module_to_observed(
                                            node,
                                            qconfig,
                                            named_modules,
                                            prepare_custom_config,
                                        )

                else:  # output
                    _maybe_insert_observers_before_graph_output(
                        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
                    )

        #
        # After this point, the current node has the input and output observers
        # that it needs for itself inserted.
        #

        # increment the counters, so future inputs and outputs are assigned
        # correct dtypes
        if node.op == "placeholder":
            inputs_seen_counter += 1
        elif node.op == "output":
            outputs_seen_counter += 1
            results_node = node

    return results_node


def _run_prepare_fx_on_standalone_modules(
    model: torch.nn.Module,
    is_qat: bool,
    named_modules: Dict[str, torch.nn.Module],
    node_name_to_match_result_with_qconfig: Any,
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
) -> None:
    """
    Runs prepare_fx on each standalone module. Note: this does not modify the
    graph; it just replaces the unobserved modules with their observed versions.
    """
    for (
        root_node,
        _,
        pattern,
        qhandler,
        qconfig,
    ) in node_name_to_match_result_with_qconfig.values():
        if qhandler is None:
            continue
        elif not qhandler.is_standalone_module():
            continue

        (
            sm_qconfig_mapping,
            sm_example_inputs,
            sm_prepare_custom_config,
            sm_backend_config,
        ) = _get_standalone_module_configs(
            root_node, named_modules, prepare_custom_config, qconfig, backend_config
        )

        standalone_module = named_modules[root_node.target]
        prepare = (
            torch.ao.quantization.quantize_fx._prepare_standalone_module_fx
        )  # type: ignore[attr-defined]
        observed_standalone_module = prepare(
            standalone_module,
            sm_qconfig_mapping,
            is_qat,
            example_inputs=sm_example_inputs,
            prepare_custom_config=sm_prepare_custom_config,
            backend_config=sm_backend_config,
        )
        parent_name, name = _parent_name(root_node.target)
        setattr(named_modules[parent_name], name, observed_standalone_module)
        named_modules[root_node.target] = observed_standalone_module



def _save_state(
    observed: GraphModule,
    node_name_to_qconfig: Dict[str, QConfigAny],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    prepare_custom_config: PrepareCustomConfig,
    equalization_node_name_to_qconfig: Dict[str, Any],
    qconfig_mapping: QConfigMapping,
    is_qat: bool,
    observed_node_names: Set[str],
) -> None:
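    """
    Stashes prepare-time state on the observed GraphModule's meta dict under
    the "_observed_graph_module_attrs" key, so that the convert step can
    retrieve it later.
    """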
    observed.meta["_observed_graph_module_attrs"] = ObservedGraphModuleAttrs(
        node_name_to_qconfig=node_name_to_qconfig,
        node_name_to_scope=node_name_to_scope,
        prepare_custom_config=prepare_custom_config,
        equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
        qconfig_mapping=qconfig_mapping,
        is_qat=is_qat,
        observed_node_names=observed_node_names,
    )


def prepare(
    model: GraphModule,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    """standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module.

    Args:
        node_name_to_scope: mapping from node name to the scope of the module which
            contains the node. The scope is a tuple of the fully qualified path of
            the module and the type of the module.

    Returns:
        model(GraphModule): prepared standalone module, with these attributes
        related to the standalone module in
        model.meta["_observed_graph_module_attrs"]:
            is_observed_standalone_module (bool): whether the current model is
                an observed standalone module or not
            standalone_module_input_quantized_idxs (List[int]): a list of
                indexes for the graph inputs that are expected to be quantized,
                same as the input_quantized_idxs configuration provided
                for the standalone module
            standalone_module_output_quantized_idxs (List[int]): a list of
                indexes for the graph outputs that are quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
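
    Example (an illustrative sketch; `prepare` is normally invoked indirectly
    through `torch.ao.quantization.quantize_fx.prepare_fx`, which handles
    tracing and the extra bookkeeping arguments)::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_fx

        float_model = MyModel().eval()  # hypothetical user model
        qconfig_mapping = get_default_qconfig_mapping("x86")
        example_inputs = (torch.randn(1, 3, 224, 224),)
        prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # `prepared` now contains observers; run calibration data through it
        # before calling convert_fx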
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(qconfig_mapping, dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)

    if isinstance(_equalization_config, dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
            "be supported in a future version. Please pass in a QConfigMapping instead.",
            FutureWarning,
            stacklevel=2,
        )
        _equalization_config = QConfigMapping.from_dict(_equalization_config)

    if isinstance(prepare_custom_config, dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    assert isinstance(qconfig_mapping, QConfigMapping)
    assert isinstance(_equalization_config, QConfigMapping)
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    _equalization_config = copy.deepcopy(_equalization_config)

    # mapping from a tuple of nodes in reverse order to an uninitialized
    #   QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.ao.quantization.fx.quantize.Add'>),
    # }

    pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
    if backend_config is None:
        backend_config = get_native_backend_config()
    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)

    root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)

    _update_qconfig_for_fusion(model, qconfig_mapping)
    _update_qconfig_for_fusion(model, _equalization_config)
    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())

    if is_qat:
        module_to_qat_module = get_module_to_qat_module(backend_config)
        _qat_swap_modules(model, module_to_qat_module)
        _update_qconfig_for_qat(qconfig_mapping, backend_config)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    named_modules = dict(model.named_modules(remove_duplicate=False))

    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, _equalization_config, node_name_to_scope
    )
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, qconfig_mapping, node_name_to_scope
    )

    # match the patterns that will get quantized
    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
    standalone_module_classes = list(
        prepare_custom_config.standalone_module_classes.keys()
    )

    custom_module_classes = get_custom_module_class_keys(
        prepare_custom_config.float_to_observed_mapping
    )
    matches_without_qconfig = _find_matches(
        model.graph,
        named_modules,
        pattern_to_quantize_handler,
        root_node_getter_mapping,
        standalone_module_names,
        standalone_module_classes,
        custom_module_classes,
    )

    # map qconfig instances to matches
    node_name_to_match_result_with_qconfig = {}
    for node_name, match_without_qconfig in matches_without_qconfig.items():
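        # append the node's qconfig to the 4-tuple match result, yielding
        #   (last_node, matched_node_pattern, pattern, qhandler, qconfig)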
        match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
        node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig

    _run_prepare_fx_on_standalone_modules(
        model,
        is_qat,
        named_modules,
        node_name_to_match_result_with_qconfig,
        prepare_custom_config,
        backend_config,
    )

    # record the names of the observed nodes, so that in the convert step
    # we know whether we need to convert a floating point module to a
    # reference quantized module or not
    observed_node_names: Set[str] = set()

    result_node = insert_observers_for_model(
        model,
        node_name_to_match_result_with_qconfig,
        node_name_to_qconfig,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        backend_config,
        observed_node_names,
        is_qat,
    )
    model = GraphModule(model, model.graph)

    _save_state(
        model,
        node_name_to_qconfig,
        node_name_to_scope,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        qconfig_mapping,
        is_qat,
        observed_node_names,
    )

    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), (
            "standalone module only supports returning a simple value currently "
            "(not tuple, dict, etc.)"
        )
        # these inputs are observed in the parent module
        input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
        output_quantized_idxs: List[
            int
        ] = prepare_custom_config.output_quantized_indexes
        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
        # inplace modification
        observed_graph_module_attrs.is_observed_standalone_module = True
        observed_graph_module_attrs.standalone_module_input_quantized_idxs = (
            input_quantized_idxs
        )
        observed_graph_module_attrs.standalone_module_output_quantized_idxs = (
            output_quantized_idxs
        )
    return model