xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/qconfig_mapping_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import re
3from collections import defaultdict, OrderedDict
4from typing import Any, Callable, Dict, List, Set, Tuple, Union
5
6import torch
7from torch.ao.nn.intrinsic import _FusedModule
8from torch.ao.quantization import QConfig
9from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig
10from torch.ao.quantization.backend_config.utils import get_module_to_qat_module
11from torch.ao.quantization.observer import _is_activation_post_process
12from torch.ao.quantization.qconfig import (
13    _add_module_to_qconfig_obs_ctr,
14    qconfig_equals,
15    QConfigAny,
16)
17from torch.ao.quantization.qconfig_mapping import (
18    _MODULE_NAME_DICT_KEY,
19    _MODULE_NAME_REGEX_DICT_KEY,
20    _OBJECT_TYPE_DICT_KEY,
21    QConfigMapping,
22)
23from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes
24from torch.fx import GraphModule
25from torch.fx.graph import Graph
26
27
28__all__: List[str] = []
29
30
31def _maybe_adjust_qconfig_for_module_name_object_type_order(
32    qconfig_mapping: QConfigMapping,
33    cur_module_path: str,
34    cur_object_type: Callable,
35    cur_object_type_idx: int,
36    fallback_qconfig: QConfigAny,
37) -> QConfigAny:
38    for (
39        module_name,
40        object_type,
41        index,
42    ), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items():
43        if (
44            (module_name == cur_module_path)
45            and (object_type == cur_object_type)
46            and (index == cur_object_type_idx)
47        ):
48            return qconfig
49    return fallback_qconfig
50
51
52def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
53    """
54    Update the QConfigMapping to account for fused modules such as LinearReLU.
55    This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
56    """
57    object_type_dict = qconfig_mapping.object_type_qconfigs
58    if len(object_type_dict) == 0:
59        return qconfig_mapping
60
61    modules = dict(model.named_modules())
62
63    for node in model.graph.nodes:
64        if node.op == "call_module" and node.target in modules:
65            maybe_fused_module = modules[str(node.target)]
66            if not isinstance(maybe_fused_module, _FusedModule):
67                continue
68
69            ops = list(maybe_fused_module._modules.values())
70            fused_qconfig = object_type_dict.get(type(ops[0]), None)
71
72            # Raise an error if the modules in the fused module have
73            # different qconfigs specified in the qconfig_dict
74            # TODO: currently it only works for modules,
75            # need to make this work for torch.nn.functional.relu
76            # TODO: currently it only works for object_type configurations,
77            # ideally it should work for different types of configurations,
78            # maybe we want to redesign this part
79            for op in ops[1:]:
80                if not qconfig_equals(
81                    object_type_dict.get(type(op), None), fused_qconfig
82                ):
83                    raise LookupError(
84                        "During fusion, we need to specify the same "
85                        + f"qconfigs for all module types in {type(maybe_fused_module)} "
86                        + f"offending type: {type(op)}"
87                    )
88
89            if fused_qconfig is not None:
90                object_type_dict[type(maybe_fused_module)] = fused_qconfig
91
92
def _generate_node_name_to_qconfig(
    root: torch.nn.Module,
    modules: Dict[str, torch.nn.Module],
    input_graph: Graph,
    qconfig_mapping: QConfigMapping,
    node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Dict[str, QConfigAny]:
    """Resolve a qconfig for every node in ``input_graph``.

    For each node, the qconfig is looked up with the precedence implemented in
    ``_maybe_adjust_qconfig_for_module_type_or_name`` (module name > module name
    regex > object type > global), then possibly refined by the
    (module path, op type, occurrence index) ordering config, and finally
    wrapped with a device check via ``_add_module_to_qconfig_obs_ctr``.

    Args:
      root: top-level module owning the graph (not referenced directly here;
        kept for interface stability)
      modules: mapping from fully-qualified module name to module instance
      input_graph: FX graph whose nodes are being annotated
      qconfig_mapping: user-specified quantization configuration
      node_name_to_scope: maps node name -> (module path, module type) of the
        module the node was traced from

    Returns:
      dict mapping node name -> resolved qconfig (may be None, e.g. for
      placeholder/output nodes)
    """
    global_qconfig = qconfig_mapping.global_qconfig
    node_name_to_qconfig: Dict[str, QConfigAny] = {}

    # example:
    #
    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
    #
    # meaning in submodule 'foo.bar', we have seen 0 F.linear and
    # 1 F.conv2d invocations so far.
    submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = defaultdict(
        lambda: defaultdict(int)
    )
    for node in input_graph.nodes:
        qconfig = None
        if node.op == "get_attr":
            # attribute access: inherit the qconfig of the module that owns
            # the attribute (node.target is a dotted path like "sub.weight")
            module_name, _ = _parent_name(node.target)
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[module_name]), module_name, global_qconfig
            )
            # NOTE(review): node.target is typically an attribute path, not a
            # module name, so modules.get(...) is usually None here — confirm
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )
        elif node.op == "call_function":
            # precedence: module_name_qconfig
            # > function_qconfig > global_qconfig
            # module_name takes precedence over function qconfig
            function_qconfig = _get_object_type_qconfig(
                qconfig_mapping, node.target, global_qconfig
            )
            module_path, module_type = node_name_to_scope[node.name]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, function_qconfig
            )

            # count how many times this function has been called within the
            # enclosing submodule, to support order-sensitive qconfigs
            cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][
                node.target
            ]
            submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig
            )
            # NOTE(review): node.target is a callable here, so modules.get(...)
            # returns None; the device check then uses no module — confirm
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )

        elif node.op == "call_method":
            module_path, module_type = node_name_to_scope[node.name]
            # first use node.target (string) to get the qconfig
            # this is to support configs like
            # "object_type": [("reshape", qconfig)]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, node.target, module_path, global_qconfig
            )
            # if there is no special config for the method, we'll fall back to the
            # config for the module that contains the call_method node
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, qconfig
            )
            # currently call_method does not support modifying qconfig
            # by order, we can add this later if it is needed.
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )

        elif node.op == "call_module":
            # if the node is an observer, just continue - don't add it to the qconfig_map
            if _is_activation_post_process(modules[node.target]):
                continue
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[node.target]), node.target, global_qconfig
            )

            module_path, module_type = node_name_to_scope[node.name]
            # Note: for call_module, the module_path is the current module's name.
            # to meaningfully count invocations, we need to count them in the parent
            # module.
            parent_name, _ = _parent_name(module_path)
            cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][
                module_type
            ]
            submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig
            )
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
                qconfig, modules.get(node.target, None)
            )

            # regex is not supported eager mode propagate_qconfig_, we'll
            # need to set the qconfig explicitly here in case regex
            # is used
            modules[node.target].qconfig = qconfig_with_device_check
        else:
            # placeholder / output / other ops do not get a qconfig
            qconfig_with_device_check = None

        node_name_to_qconfig[node.name] = qconfig_with_device_check
    return node_name_to_qconfig
197
198
199def _check_is_valid_config_dict(
200    config_dict: Any, allowed_keys: Set[str], dict_name: str
201) -> None:
202    r"""Checks if the given config_dict has the correct keys
203
204    Args:
205      `config_dict`: dictionary whose keys we want to check
206    """
207
208    for k in config_dict.keys():
209        if k not in allowed_keys:
210            raise ValueError(
211                "Expected "
212                + dict_name
213                + " to have the following keys: "
214                + str(allowed_keys)
215                + ". But found '"
216                + k
217                + "' instead."
218            )
219
220
221def _compare_prepare_convert_qconfig_mappings(
222    prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping
223):
224    r"""Compare the qconfig_mapping passed in convert to the one from prepare and check the values
225
226    Args:
227      `prepare_qconfig_mapping`: configuration for prepare quantization step
228      `convert_qconfig_mapping`: configuration for convert quantization step
229    """
230    assert qconfig_equals(
231        prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
232    ), "Expected global qconfigs to be the same in the prepare and convert quantization configs"
233    prepare_dicts: List[OrderedDict] = [
234        prepare_qconfig_mapping.object_type_qconfigs,
235        prepare_qconfig_mapping.module_name_qconfigs,
236        prepare_qconfig_mapping.module_name_regex_qconfigs,
237    ]
238    convert_dicts: List[OrderedDict] = [
239        convert_qconfig_mapping.object_type_qconfigs,
240        convert_qconfig_mapping.module_name_qconfigs,
241        convert_qconfig_mapping.module_name_regex_qconfigs,
242    ]
243    dict_names = [
244        _OBJECT_TYPE_DICT_KEY,
245        _MODULE_NAME_DICT_KEY,
246        _MODULE_NAME_REGEX_DICT_KEY,
247    ]
248    for i in range(len(prepare_dicts)):
249        for name in prepare_dicts[i].keys():
250            assert (
251                name in convert_dicts[i]
252            ), f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
253                when it was present in prepare"
254            assert convert_dicts[i][name] is None or qconfig_equals(
255                prepare_dicts[i][name], convert_dicts[i][name]
256            ), f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
257                prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
258
259
260def _is_qconfig_supported_by_dtype_configs(
261    qconfig: QConfig, dtype_configs: List[DTypeConfig]
262):
263    for dtype_config in dtype_configs:
264        is_dynamic = dtype_config.is_dynamic
265        if is_dynamic is None:
266            is_dynamic = False
267        input_dtype = dtype_config.input_dtype or torch.float
268        weight_dtype = dtype_config.weight_dtype or torch.float
269        bias_dtype = dtype_config.bias_dtype or torch.float
270        output_dtype = dtype_config.output_dtype or torch.float
271        (
272            qconfig_activation_dtype,
273            qconfig_weight_dtype,
274            qconfig_input_act_is_dynamic,
275        ) = get_qconfig_dtypes(qconfig)
276        qconfig_bias_dtype = (
277            torch.float16
278            if (
279                qconfig_activation_dtype == torch.float16
280                and qconfig_weight_dtype == torch.float16
281                and not is_dynamic
282            )
283            else torch.float
284        )
285
286        if is_dynamic:
287            is_match = (
288                qconfig_input_act_is_dynamic
289                and input_dtype == qconfig_activation_dtype
290                and output_dtype == torch.float
291                and weight_dtype == qconfig_weight_dtype
292            )
293        else:
294            is_match = (
295                input_dtype == qconfig_activation_dtype
296                and output_dtype == qconfig_activation_dtype
297                and weight_dtype == qconfig_weight_dtype
298                and bias_dtype == qconfig_bias_dtype
299            )
300        if is_match:
301            return True
302    return False
303
304
305def _get_object_type_qconfig(
306    qconfig_mapping: QConfigMapping,
307    object_type: Union[Callable, str],
308    fallback_qconfig: QConfigAny,
309) -> QConfigAny:
310    return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
311
312
313def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
314    for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
315        if re.match(regex_pattern, module_name):
316            # first match wins
317            return qconfig
318    return fallback_qconfig
319
320
321def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
322    if module_name == "":
323        # module name qconfig not found
324        return fallback_qconfig
325    if module_name in qconfig_mapping.module_name_qconfigs:
326        return qconfig_mapping.module_name_qconfigs[module_name]
327    else:
328        parent, _ = _parent_name(module_name)
329        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
330
331
def _maybe_adjust_qconfig_for_module_type_or_name(
    qconfig_mapping, module_type, module_name, global_qconfig
):
    """Resolve the qconfig for a module, in increasing order of precedence:
    global < object type < module name regex < exact module name.

    Each lookup uses the previous result as its fallback, so the most
    specific configured entry wins.
    """
    qconfig = _get_object_type_qconfig(qconfig_mapping, module_type, global_qconfig)
    qconfig = _get_module_name_regex_qconfig(qconfig_mapping, module_name, qconfig)
    qconfig = _get_module_name_qconfig(qconfig_mapping, module_name, qconfig)
    return qconfig
348
349
350def _get_flattened_qconfig_dict(
351    qconfig_mapping: QConfigMapping,
352) -> Dict[Union[Callable, str], QConfigAny]:
353    """flatten the global, object_type and module_name qconfig
354    to the same qconfig_dict so that it can be used by
355    propagate_qconfig_ function.
356    "module_name_regex" is ignored for now since it's not supported
357    in propagate_qconfig_, but it can be fixed later.
358
359    For example:
360    Input: {
361      "": qconfig,
362      "object_type": [
363        (torch.add, qconfig)
364      ],
365      "module_name": [
366        ("conv", qconfig)
367      ]
368    }
369
370    Output: {
371      "": qconfig,
372      torch.add: qconfig,
373      "conv": qconfig
374    }
375    """
376    flattened: Dict[Union[Callable, str], QConfigAny] = {
377        "": qconfig_mapping.global_qconfig
378    }
379    for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
380        flattened[obj] = qconfig
381    for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
382        flattened[obj] = qconfig
383    return flattened
384
385
386def _update_qconfig_for_qat(
387    qconfig_mapping: QConfigMapping, backend_config: BackendConfig
388):
389    """
390    Update the qconfig_mapping to account for module swaps during QAT.
391    During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
392    """
393    module_to_qat_module_class = get_module_to_qat_module(backend_config)
394    object_type_dict = qconfig_mapping.object_type_qconfigs
395    new_object_type_dict = object_type_dict.copy()
396    for k, v in new_object_type_dict.items():
397        if k in module_to_qat_module_class:
398            object_type_dict[module_to_qat_module_class[k]] = v
399