# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import _reverse2, _reverse3
from torch.ao.quantization.utils import Pattern

from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig


__all__ = [
    "get_pattern_to_dtype_configs",
    "get_qat_module_classes",
    "get_fused_module_classes",
    "get_pattern_to_input_type_to_index",
    "get_root_module_to_quantized_reference_module",
    "get_fuser_method_mapping",
    "get_module_to_qat_module",
    "get_fusion_pattern_to_root_node_getter",
    "get_fusion_pattern_to_extra_inputs_getter",
    "remove_boolean_dispatch_from_name",
    "pattern_to_human_readable",
    "entry_to_pretty_str",
]


def get_pattern_to_dtype_configs(
    backend_config: BackendConfig,
) -> Dict[Pattern, List[DTypeConfig]]:
    pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        pattern_to_dtype_configs[pattern] = config.dtype_configs
    return pattern_to_dtype_configs
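

# A minimal usage sketch for get_pattern_to_dtype_configs (hypothetical values,
# assuming the default native backend config that ships with PyTorch):
#
#   from torch.ao.quantization.backend_config import get_native_backend_config
#   pattern_to_dtype_configs = get_pattern_to_dtype_configs(get_native_backend_config())
#   # Keys are patterns in the internal reversed nested tuple format,
#   # e.g. torch.nn.Linear or (torch.nn.ReLU, torch.nn.Linear)
#   linear_dtype_configs = pattern_to_dtype_configs[torch.nn.Linear]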


def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
    qat_module_classes = []
    for config in backend_config.configs:
        if config.qat_module is not None:
            qat_module_classes.append(config.qat_module)
    return tuple(set(qat_module_classes))


def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
    fused_module_classes = []
    for config in backend_config.configs:
        if config.fused_module is not None:
            fused_module_classes.append(config.fused_module)
    return tuple(set(fused_module_classes))
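

# A minimal usage sketch for the two class getters above (hypothetical,
# assuming the default native backend config):
#
#   from torch.ao.quantization.backend_config import get_native_backend_config
#   backend_config = get_native_backend_config()
#   qat_classes = get_qat_module_classes(backend_config)      # e.g. contains torch.ao.nn.qat.Linear
#   fused_classes = get_fused_module_classes(backend_config)  # e.g. contains torch.ao.nn.intrinsic.LinearReLU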


def get_pattern_to_input_type_to_index(
    backend_config: BackendConfig,
) -> Dict[Pattern, Dict[str, int]]:
    pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        pattern_to_input_type_to_index[pattern] = config._input_type_to_index
    return pattern_to_input_type_to_index
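

# For illustration (a hedged sketch of how weighted functional ops are
# typically configured in the native backend config): the inner dict records
# which argument indices hold the weight and bias, e.g. for
# torch.nn.functional.linear the lookup might return:
#
#   pattern_to_input_type_to_index[torch.nn.functional.linear]
#   # -> {"weight": 1, "bias": 2}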


def get_root_module_to_quantized_reference_module(
    backend_config: BackendConfig,
) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]:
    mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {}
    for config in backend_config.configs:
        if (
            config.root_module is not None
            and config.reference_quantized_module is not None
        ):
            mapping[config.root_module] = config.reference_quantized_module
    return mapping
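

# A minimal usage sketch for get_root_module_to_quantized_reference_module
# (hypothetical, assuming the default native backend config):
#
#   from torch.ao.quantization.backend_config import get_native_backend_config
#   mapping = get_root_module_to_quantized_reference_module(get_native_backend_config())
#   # e.g. maps torch.nn.Linear to torch.ao.nn.quantized.reference.Linear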


def get_fuser_method_mapping(
    backend_config: BackendConfig,
) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
    fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.fuser_method is not None:
            # Note: both the fuser method and the pattern are specified in forward order in the
            # BackendConfig, but the internal pattern matching code uses the reversed nested tuple
            # format, so we need to convert both to the internal format
            fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config)
            fuser_method_mapping[pattern] = fuser_method
    return fuser_method_mapping
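

# For illustration (a hedged sketch, not actual library state): a user-facing
# pattern like (torch.nn.Linear, torch.nn.ReLU) is stored internally as
# (torch.nn.ReLU, torch.nn.Linear), and the fuser method returned here is
# wrapped so that the internal call
#
#   fuser_method(is_qat, relu, linear)   # reversed argument order
#
# dispatches to the user-supplied method in forward order, i.e.
# f(is_qat, linear, relu).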


def get_module_to_qat_module(
    backend_config: BackendConfig,
) -> Dict[Pattern, Type[torch.nn.Module]]:
    module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.qat_module is not None:
            module_to_qat_module[pattern] = config.qat_module
    return module_to_qat_module


def get_fusion_pattern_to_root_node_getter(
    backend_config: BackendConfig,
) -> Dict[Pattern, Callable]:
    """Get a map from fusion pattern to a function that returns the root node
    from the fusion pattern. The most common root node getter is::

        def get_root_node(node_pattern):
            while not isinstance(node_pattern[-1], Node):
                node_pattern = node_pattern[-1]
            return node_pattern[-1]

    This works for all patterns whose root node is the "last node" in the
    pattern, e.g. (torch.add, MatchAllNode, (nn.ReLU, nn.Conv2d)).
    """
    root_node_getter_mapping: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config._root_node_getter is not None:
            root_node_getter_mapping[pattern] = config._root_node_getter
    return root_node_getter_mapping


def get_fusion_pattern_to_extra_inputs_getter(
    backend_config: BackendConfig,
) -> Dict[Pattern, Callable]:
    """Get a map from fusion pattern to a function that returns extra input nodes
    from the fusion pattern, in the order required by the root node. This is
    optional; if not specified, no extra inputs are copied over for the root node.

    Example: say we have the pattern
    (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d)), the root
    node is torch.nn.Conv2d, and the node matched by MatchAllNode should become
    an extra argument to the fused module. We can unpack the pattern and return
    the node at MatchAllNode with an extra inputs getter like::

        def extra_inputs_getter(pattern) -> List[Any]:
            add, extra_input, conv_pattern = pattern
            return [extra_input]
    """
    extra_inputs_getter_mapping: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config._extra_inputs_getter is not None:
            extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter
    return extra_inputs_getter_mapping


def remove_boolean_dispatch_from_name(p) -> Any:
    """
    Some ops have a default string representation such as
    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>';
    this function replaces such representations with hardcoded,
    human-readable function names.
    """
    if p is F.fractional_max_pool2d:
        return "torch.nn.functional.fractional_max_pool2d"
    elif p is F.fractional_max_pool3d:
        return "torch.nn.functional.fractional_max_pool3d"
    elif p is F.max_pool1d:
        return "torch.nn.functional.max_pool1d"
    elif p is F.max_pool2d:
        return "torch.nn.functional.max_pool2d"
    elif p is F.max_pool3d:
        return "torch.nn.functional.max_pool3d"
    elif p is F.adaptive_max_pool1d:
        return "torch.nn.functional.adaptive_max_pool1d"
    elif p is F.adaptive_max_pool2d:
        return "torch.nn.functional.adaptive_max_pool2d"
    elif p is F.adaptive_max_pool3d:
        return "torch.nn.functional.adaptive_max_pool3d"
    assert "boolean_dispatch" not in str(p), (
        f"{p} does not have a human readable representation in "
        + "quantization documentation"
    )
    return p
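

# For context (a hedged note): ops like F.max_pool2d are defined via
# torch._jit_internal.boolean_dispatch, which returns a local closure, so their
# default repr is uninformative, e.g.:
#
#   print(F.max_pool2d)
#   # <function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>  (address varies)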


def pattern_to_human_readable(p) -> Any:
    if isinstance(p, tuple):
        # nested patterns, recurse
        return tuple(pattern_to_human_readable(inner_p) for inner_p in p)
    elif isinstance(p, str):
        # method names are already human readable
        return p
    else:
        return remove_boolean_dispatch_from_name(p)
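

# A minimal illustration (hypothetical values): boolean-dispatched functions
# are replaced by their names, while other ops pass through unchanged:
#
#   pattern_to_human_readable((F.max_pool2d, torch.nn.Conv2d))
#   # -> ('torch.nn.functional.max_pool2d', <class 'torch.nn.modules.conv.Conv2d'>)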


# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
# the corresponding __str__ function
def entry_to_pretty_str(entry) -> str:
    """
    Given a backend_config_dict entry, return a human-readable string
    representation of it.
    """
    s = "{\n"

    # always output the pattern first
    if "pattern" in entry:
        pattern_str = pattern_to_human_readable(entry["pattern"])
        s += f"  'pattern': {pattern_str},\n"

    # custom output for dtype_configs to make it look nice
    if "dtype_configs" in entry:
        s += "  'dtype_configs': [\n"
        for dtype_config in entry["dtype_configs"]:
            s += "    {\n"
            for k, v in dtype_config.items():
                s += f"      '{k}': {v},\n"
            s += "    },\n"
        s += "  ],\n"

    # custom output for num_tensor_args_to_observation_type to make it look nice
    if "num_tensor_args_to_observation_type" in entry:
        s += "  'num_tensor_args_to_observation_type': {\n"
        for k, v in entry["num_tensor_args_to_observation_type"].items():
            s += f"    {k}: {v},\n"
        s += "  },\n"

    # output all the other fields
    custom_handled_fields = [
        "pattern",
        "dtype_configs",
        "num_tensor_args_to_observation_type",
    ]
    for field_name in entry:
        if field_name in custom_handled_fields:
            continue
        s += f"  '{field_name}': {entry[field_name]},\n"

    s += "}"
    return s
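

# A minimal illustration of the output shape (hypothetical entry and values):
#
#   print(entry_to_pretty_str(
#       {"pattern": torch.nn.Linear, "dtype_configs": [{"input_dtype": torch.quint8}]}
#   ))
#   # {
#   #   'pattern': <class 'torch.nn.modules.linear.Linear'>,
#   #   'dtype_configs': [
#   #     {
#   #       'input_dtype': torch.quint8,
#   #     },
#   #   ],
#   # }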


def _get_pattern_in_reversed_nested_tuple_format(
    config: BackendPatternConfig,
) -> Pattern:
    """
    Return the pattern specified in the given config in the reversed nested tuple
    format used internally in the quantization pattern matching code.

    If the pattern is not a tuple, or the pattern is already specified in the
    reversed nested tuple format, return the pattern as is. Otherwise:

    For 2-tuples (a, b), return (b, a).
    For 3-tuples (a, b, c), return (c, (b, a)).

    For example:
        * Given nn.Linear, return nn.Linear
        * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
        * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
          (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))

    For context, this conversion is needed because the user-facing BackendConfig
    API accepts the flat 2-or-3-tuple format in forward order. While this simple
    format handles the vast majority of use cases, it does not handle the more
    complex ones, so the internal pattern matching code for quantization uses
    the following, more general reversed nested tuple format instead:

        operator = module_type | functional | torch op | native op | MatchAllNode
        Pattern = (operator, Pattern, Pattern, ...) | operator

    In the future, we expect to replace this complex format with the one used
    by the subgraph rewriter in torch.fx, so we no longer have to maintain our
    own complex pattern matching code. Then this helper function will no longer
    be needed.
    """
    if config._pattern_complex_format is not None:
        return config._pattern_complex_format
    if config.pattern is None:
        raise ValueError(
            "Either 'pattern' or 'pattern_complex_format' must be specified"
        )
    if not isinstance(config.pattern, tuple):
        return config.pattern

    # Pattern is specified in the simple tuple format, need to convert
    if len(config.pattern) == 2:
        (a, b) = config.pattern
        return (b, a)
    elif len(config.pattern) == 3:
        (a, b, c) = config.pattern
        return (c, (b, a))
    else:
        raise ValueError(
            f"Expected a tuple with 2 or 3 elements, got: {config.pattern}"
        )


def _get_fuser_method_in_reversed_nested_tuple_format(
    config: BackendPatternConfig,
) -> Callable:
    """
    Return the fuser method specified in the given config in the reversed nested
    tuple format used internally in the quantization pattern matching code.

    If the pattern is specified in the reversed nested tuple format, we assume the
    fuser method is also specified in this format and simply return it as is.
    Otherwise, we convert the fuser method as follows:

        * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
        * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
          where bn_conv is a 2-tuple (bn, conv)

    The first argument of a fuser method is always `is_qat` and is not affected
    by the conversion. We currently only support functions with 3 or 4 arguments.
    """
    assert config.fuser_method is not None
    if config._pattern_complex_format is not None:
        return config.fuser_method
    if not isinstance(config.pattern, tuple):
        raise ValueError(f"Expected pattern to be a tuple, got: {config.pattern}")

    # Pattern is specified in the simple tuple format, need to convert
    if len(config.pattern) == 2:
        return _reverse2(config.fuser_method)
    elif len(config.pattern) == 3:
        return _reverse3(config.fuser_method)
    else:
        raise ValueError(
            f"Expected a tuple with 2 or 3 elements, got: {config.pattern}"
        )
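

# For illustration (a hedged sketch of what the _reverse2 wrapper does, based
# on the behavior described in the docstring above; fuse_linear_relu is a
# hypothetical fuser method written in forward order):
#
#   def fuse_linear_relu(is_qat, linear, relu):
#       ...
#
#   reversed_fn = _reverse2(fuse_linear_relu)
#   # The internal matcher calls reversed_fn(is_qat, relu, linear), which
#   # forwards to fuse_linear_relu(is_qat, linear, relu).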
313