xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/fuse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import warnings
3from typing import Any, Callable, Dict, List, Tuple, Union
4
5from torch.ao.quantization.backend_config import (
6    BackendConfig,
7    get_native_backend_config,
8)
9from torch.ao.quantization.backend_config.utils import (
10    get_fuser_method_mapping,
11    get_fusion_pattern_to_extra_inputs_getter,
12    get_fusion_pattern_to_root_node_getter,
13)
14from torch.ao.quantization.utils import NodePattern, Pattern
15from torch.fx import GraphModule, map_arg, Node
16from torch.fx.graph import Graph
17
18from .custom_config import FuseCustomConfig
19from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler
20from .match_utils import _is_match, MatchAllNode
21from .pattern_utils import _sorted_patterns_dict
22
23
# Names exported by ``from <this module> import *``.
__all__ = [
    "fuse",
    # TODO: Make this private in the future.
    # FuseHandler is not defined in this module; it is imported from
    # .fuse_handler and re-exported here because test_public_bindings
    # currently needs it (the exact reason is not recorded).
    "FuseHandler",
]
30
31
def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    """Rebuild ``model``'s graph with every matched fusion pattern fused.

    The fusion patterns, fuser methods, root-node getters and extra-input
    getters are all derived from ``backend_config``. The original graph is
    walked once in order; each node is either replaced by the result of its
    fuse handler, copied verbatim into the new graph, or dropped because it
    was absorbed by a match.

    Args:
        model: The ``GraphModule`` whose graph is scanned for patterns.
        is_qat: Passed through unchanged to each handler's ``fuse`` call.
        fuse_custom_config: Extra fusion settings handed to the handlers.
            ``None`` means a default ``FuseCustomConfig()``. A ``dict`` is
            deprecated: it triggers a ``FutureWarning`` and is converted
            with ``FuseCustomConfig.from_dict``.
        backend_config: Source of the fusion patterns. ``None`` means
            ``get_native_backend_config()``. A ``dict`` is deprecated: it
            triggers a ``FutureWarning`` and is converted with
            ``BackendConfig.from_dict``.

    Returns:
        A new ``GraphModule`` built from ``model`` and the fused graph.
    """
    # --- normalise the two config arguments -------------------------------
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    # --- lookup tables derived from the backend config --------------------
    handler_cls_by_pattern = _sorted_patterns_dict(
        _get_fusion_pattern_to_fuse_handler_cls(backend_config)
    )
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    root_node_getters = get_fusion_pattern_to_root_node_getter(backend_config)
    extra_inputs_getters = get_fusion_pattern_to_extra_inputs_getter(backend_config)

    named_modules = dict(model.named_modules())

    # Node name -> match record for every node that takes part in a match.
    matches = _find_matches(model, model.graph, handler_cls_by_pattern)

    # TODO: switch to in-place graph edits, since a brand new GraphModule
    # is no longer being constructed from scratch.
    fused_graph = Graph()
    # Original node name -> the node that stands in for it in fused_graph.
    env: Dict[Any, Any] = {}

    def _lookup(old_node):
        return env[old_node.name]

    def load_arg(a):
        # Translate references to original nodes into their new-graph nodes.
        return map_arg(a, _lookup)

    def default_root_node_getter(node_pattern):
        # Keep descending into the last element until it is an actual Node.
        tail = node_pattern[-1]
        while not isinstance(tail, Node):
            tail = tail[-1]
        return tail

    no_match = (None, None, None, None, None)
    for node in model.graph.nodes:
        (
            last_node,
            pattern,
            matched_node_pattern,
            handler,
            subpattern_by_node,
        ) = matches.get(node.name, no_match)

        if last_node is node:
            # This is the node the match was anchored on: emit the fused op.
            assert handler is not None
            get_root = root_node_getters.get(pattern, default_root_node_getter)
            root_node = get_root(matched_node_pattern)  # type: ignore[index]
            get_extra_inputs = extra_inputs_getters.get(pattern, None)
            extra_inputs = (
                []
                if get_extra_inputs is None
                else get_extra_inputs(matched_node_pattern)
            )
            # TODO: validate that root_node is a module of the same type as
            # the root_module named in the configuration.
            env[node.name] = handler.fuse(
                load_arg,
                named_modules,
                fused_graph,
                root_node,
                extra_inputs,
                matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config,
                fuser_method_mapping,
                is_qat,
            )
            continue

        # Subpattern this particular node was matched against, if any.
        if subpattern_by_node is None:
            node_subpattern = None
        else:
            node_subpattern = subpattern_by_node.get(node, None)

        if last_node is None or node_subpattern is MatchAllNode:
            env[node.name] = fused_graph.node_copy(node, load_arg)
        # Otherwise the node belongs to a match without being its anchor,
        # so it is intentionally left out of the new graph.

    return GraphModule(model, fused_graph)
134
135
136def _find_matches(
137    root: GraphModule,
138    graph: Graph,
139    pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
140) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
141    modules = dict(root.named_modules())
142    # node name -> (root_node, match_value)
143    match_map: Dict[
144        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]
145    ] = {}
146    # a map from node to the matched subpattern
147    node_to_subpattern: Dict[Node, Any] = {}
148
149    # TODO: dedup with quantization matching function in match_utils.py
150    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
151        if isinstance(pattern, tuple):
152            s, *args = pattern
153            current_node_pattern: List[Node] = []
154            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
155            for subpattern, arg in zip(args, node.args):
156                apply_match(
157                    subpattern, arg, match, current_node_pattern, node_to_subpattern
158                )
159            matched_node_pattern.append(tuple(current_node_pattern))
160        else:
161            # the first pattern matches will take precedence
162            if node.name not in match_map:
163                matched_node_pattern.append(node)
164                # MatchAllNode here is actually MatchAllInputNode which should not
165                # be added to match_map
166                if pattern is not MatchAllNode:
167                    node_to_subpattern[node] = pattern
168                    root_node, pattern, handler = match
169                    match_map[node.name] = (
170                        root_node,
171                        pattern,
172                        matched_node_pattern,
173                        handler,
174                        node_to_subpattern,
175                    )
176
177    for node in reversed(graph.nodes):
178        if node.name not in match_map:
179            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
180                matched_node_pattern: List[Node] = []
181                if _is_match(modules, node, pattern):
182                    apply_match(
183                        pattern,
184                        node,
185                        (node, pattern, fuse_handler_cls(node)),
186                        matched_node_pattern,
187                        node_to_subpattern,
188                    )
189                    break
190
191    return match_map
192