from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSNodeTargetType


toq = torch.ops.quantized


def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
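    # Illustrative example (the input mapping below is hypothetical, not taken
    # from this file):
    #   {"linear": {nn.Linear, F.linear}} ->
    #       {(nn.Linear, nn.Linear), (nn.Linear, F.linear),
    #        (F.linear, nn.Linear), (F.linear, F.linear)}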
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for s in base_name_to_sets_of_related_ops.values():
        s_list = list(s)
        # add every pair in both directions, including each op paired with itself
        for idx_0 in range(0, len(s_list)):
            for idx_1 in range(idx_0, len(s_list)):
                type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
                type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))

    return type_a_related_to_b


NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    Tuple[
        str, Any
    ],  # call_method name and first argument, example: ("to", torch.float16)
]
NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
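# Example NSFusionType values (drawn from the fusion patterns defined below):
#   (nn.BatchNorm2d, nn.Conv2d)
#   (("to", torch.float16), F.relu, F.linear, "dequantize")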


def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Returns a list of potential fusions, in reverse order.  The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the index of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
    of 0 represents the first op in regular (non-reverse) order, 1 represents the
    second op, etc.
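
    Example: the reversed fusion (nn.ReLU, nn.Conv2d) corresponds to the regular
    order Conv2d -> ReLU, and a base_op_idx of 0 refers to the Conv2d op.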
63    """
64    results: List[Tuple[NSFusionType, int]] = []
65
66    # Possible syntaxes:
67    # * single op: torch.nn.Conv2d
68    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
69    # For fusions, we only care about patterns composed of multiple ops.
70    # TODO(future PR): allow customizations from default patterns.
71    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())
72
73    default_base_op_idx = 0
74    for quant_pattern in all_quant_patterns.keys():
        # TODO: this is a temporary hack to flatten the patterns from quantization
        # so that they work with the NS matcher function; maybe we should use
        # `_is_match` in torch.ao.quantization.fx.match_utils to match the
        # patterns instead
        if (
            isinstance(quant_pattern, tuple)
            and len(quant_pattern) == 2
            and isinstance(quant_pattern[1], tuple)
            and len(quant_pattern[1]) == 2
        ):
            # flatten a pattern of the form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])

        # Only patterns of multiple ops are fusions; ignore
        # patterns which contain a single op (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern, add additional patterns with observers and
        # fake quants at the end.
        # TODO(future PR): if needed, implement matching for a node
        #   having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings.  For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        (
            (("to", torch.float16), F.relu, F.linear, "dequantize"),
            fp16_em_base_op_idx,
        ),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
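    # Add each of these patterns, plus variants with an output observer or fake
    # quant attached at the end (prepended here, since the patterns are reversed).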
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]

    return results


def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns True if the pattern ending with `end_node` matches
    `reversed_fusion`.
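
    Matching walks backwards from `end_node` through each node's first input
    (args[0]), comparing each visited node against the corresponding element
    of `reversed_fusion`.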
143    """
144    cur_node = end_node
145    for fusion_idx in range(len(reversed_fusion)):
146        # each node can only belong to one matched pattern
147        if cur_node in seen_nodes:
148            return False
149
150        cur_fusion_el = reversed_fusion[fusion_idx]
151
152        if cur_node.op == "call_function":
153            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and (
154                not isinstance(cur_fusion_el, type)
155            )
156            if fusion_el_is_fun:
157                if cur_node.target != cur_fusion_el:
158                    return False
159                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
160                    cur_node = cur_node.args[0]
161                else:
162                    return False
163            else:
164                return False
165
166        elif cur_node.op == "call_module":
167            fusion_el_is_mod = isinstance(cur_fusion_el, type)
168            if fusion_el_is_mod:
169                assert isinstance(cur_node.target, str)
170                target_mod = getattr_from_fqn(gm, cur_node.target)
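                # note: this isinstance check repeats fusion_el_is_mod above;
                # presumably it is kept so the type checker narrows
                # cur_fusion_el before the isinstance call below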
                if not isinstance(cur_fusion_el, type):
                    return False
                if not isinstance(target_mod, cur_fusion_el):
                    return False
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == "call_method":
            fusion_el_is_meth_with_second_arg = (
                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
            )
            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
                if fusion_el_is_meth_without_args:
                    if cur_node.target != cur_fusion_el:
                        return False
                else:
                    assert isinstance(cur_fusion_el, tuple)
                    if cur_node.target != cur_fusion_el[0]:
                        return False
                    elif len(cur_node.args) < 2:
                        return False
                    elif cur_node.args[1] != cur_fusion_el[1]:
                        return False

                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False
        else:
            return False

    return True