# mypy: allow-untyped-defs
import copy
from collections import OrderedDict
from typing import Any, Dict

from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.utils import Pattern


__all__ = [
    "get_default_fusion_patterns",
    "get_default_quant_patterns",
    "get_default_output_activation_post_process_map",
]

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any

# pattern for conv bn fusion
_DEFAULT_FUSION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()


def _register_fusion_pattern(pattern):
    def insert(fn):
        _DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn

    return insert


def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_FUSION_PATTERNS)
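

# Usage sketch (illustrative only; `some_pattern` is a stand-in key, not something
# this module registers): the getter returns a shallow copy, so callers can add or
# drop entries without mutating the global registry.
#
# patterns = get_default_fusion_patterns()
# patterns.pop(some_pattern, None)  # _DEFAULT_FUSION_PATTERNS is unchanged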


_DEFAULT_QUANTIZATION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()

# Mapping from pattern to activation_post_process (observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: Dict[Pattern, QuantizeHandler] = {}
_DEFAULT_OUTPUT_OBSERVER_MAP: Dict[Pattern, QuantizeHandler] = {}


# Register pattern for both static quantization and qat
def _register_quant_pattern(pattern, fixed_qparams_observer=None):
    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if fixed_qparams_observer is not None:
            _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[
                pattern
            ] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
            _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return fn

    return insert
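

# A minimal registration sketch (illustrative only; `SigmoidQuantizeHandler` is a
# hypothetical handler class, and using default_fixed_qparams_range_0to1_observer
# here is an assumption, not a registration made by this module):
#
# from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer
#
# @_register_quant_pattern(torch.sigmoid, default_fixed_qparams_range_0to1_observer)
# class SigmoidQuantizeHandler:
#     ...
#
# # torch.sigmoid now maps to the handler in _DEFAULT_QUANTIZATION_PATTERNS, to
# # FixedQParamsFakeQuantize.with_args(observer=...) in _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP,
# # and to the observer constructor itself in _DEFAULT_OUTPUT_OBSERVER_MAP.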


# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)


# A map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_fixed_qparams_range_0to1_fake_quant
def get_default_output_activation_post_process_map(
    is_training,
) -> Dict[Pattern, ObserverBase]:
    if is_training:
        return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
    else:
        return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)
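

# Usage sketch (illustrative only): during QAT (is_training=True) the map yields
# FixedQParamsFakeQuantize constructors; for PTQ (is_training=False) it yields the
# fixed-qparams observer constructors registered via _register_quant_pattern.
#
# qat_map = get_default_output_activation_post_process_map(is_training=True)
# ptq_map = get_default_output_activation_post_process_map(is_training=False)
# act_post_process_ctr = qat_map.get(torch.sigmoid)  # None if sigmoid was not registered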


# Example use of register pattern function:
# @_register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion:
#     def __init__(...):
#         ...
#


def _sorted_patterns_dict(
    patterns_dict: Dict[Pattern, QuantizeHandler]
) -> Dict[Pattern, QuantizeHandler]:
    """
    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
    e.g. match (F.relu, F.linear) before F.relu.
    This works for current use cases, but we may need a more clever way to sort
    things to address more complex patterns.
    """

    def get_len(pattern):
        """Calculate the length of the pattern by counting all the entries in it.
        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
        (nn.BatchNorm, nn.Conv2d) so that we can match the former first.
        """
        pattern_len = 0
        if isinstance(pattern, tuple):
            for item in pattern:
                pattern_len += get_len(item)
        else:
            pattern_len += 1
        return pattern_len

    return OrderedDict(
        sorted(
            patterns_dict.items(),
            key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1,
        )
    )
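

# A small ordering sketch (illustrative only; `handler` stands in for any QuantizeHandler):
#
# d = {
#     torch.nn.ReLU: handler,
#     (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)): handler,
#     (torch.nn.BatchNorm2d, torch.nn.Conv2d): handler,
# }
# list(_sorted_patterns_dict(d).keys())
# # [(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)),  # length 3, matched first
# #  (torch.nn.BatchNorm2d, torch.nn.Conv2d),                   # length 2
# #  torch.nn.ReLU]                                              # non-tuple, matched last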