xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/fuse_handler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from abc import ABC, abstractmethod
3from typing import Any, Callable, Dict, List, Union
4
5import torch
6from torch.ao.quantization.backend_config import BackendConfig
7from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new
8from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern
9from torch.fx.graph import Graph, Node
10from torch.nn.utils.parametrize import type_before_parametrizations
11
12from .custom_config import FuseCustomConfig
13from .match_utils import MatchAllNode
14
15
16__all__ = [
17    "DefaultFuseHandler",
18    "FuseHandler",
19]
20
21
22# ----------------------------
23# Fusion Pattern Registrations
24# ----------------------------
25
26
27# Base Pattern Handler
class FuseHandler(ABC):
    """Abstract interface shared by all fusion-pattern handlers.

    Concrete handlers take a ``Node`` at construction time and implement
    :meth:`fuse`, which produces the ``Node`` that ends up in the fused graph.
    Both methods are abstract, so this class cannot be instantiated directly.
    """

    @abstractmethod
    def __init__(self, node: Node):
        """Set up a handler for ``node``; the base implementation does nothing."""

    @abstractmethod
    def fuse(
        self,
        load_arg: Callable,
        named_modules: Dict[str, torch.nn.Module],
        fused_graph: Graph,
        root_node: Node,
        extra_inputs: List[Any],
        matched_node_pattern: NodePattern,
        fuse_custom_config: FuseCustomConfig,
        fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
        is_qat: bool,
    ) -> Node:
        """Fuse the matched pattern and return the resulting ``Node``.

        The base implementation does nothing and returns ``None``; subclasses
        must override it.
        """
49
50
class DefaultFuseHandler(FuseHandler):
    """Fuse handler used for every pattern whose backend config has a fuser method.

    It collects the modules (or functions) matched by a pattern, builds a fused
    module with the fuser method registered for their types, installs that
    module in place of the root module, and copies the root ``call_module`` node
    into the fused graph.
    """

    def __init__(self, node: Node):
        super().__init__(node)  # type:ignore[safe-super]

    def fuse(
        self,
        load_arg: Callable,
        named_modules: Dict[str, torch.nn.Module],
        fused_graph: Graph,
        root_node: Node,
        extra_inputs: List[Any],
        matched_node_pattern: NodePattern,
        fuse_custom_config: FuseCustomConfig,
        fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
        is_qat: bool,
    ) -> Node:
        """Fuse the modules of ``matched_node_pattern`` and emit the root node.

        Args:
            load_arg: callable used to remap arguments from the original graph
                into ``fused_graph``. It is passed to ``Graph.node_copy`` for the
                root node and applied to each element of ``extra_inputs``.
            named_modules: module name -> module. It must contain the root
                module and its parent. The parent is mutated in place via
                ``setattr`` so that the fused module replaces the root module.
            fused_graph: graph that receives the copy of ``root_node``.
            root_node: the ``call_module`` node whose module gets replaced.
            extra_inputs: extra arguments. After ``load_arg`` is applied, they
                are appended to the copied node's positional args.
            matched_node_pattern: node, or nested tuple/list of nodes, matched
                for this pattern.
            fuse_custom_config: not used by this handler.
            fuser_method_mapping: pattern of module types -> fuser method.
                Resolved through ``get_fuser_method_new``.
            is_qat: passed as the first argument to the fuser method.

        Returns:
            The node copied into ``fused_graph``, with ``extra_inputs`` appended
            to its args.

        Raises:
            AssertionError: if ``root_node`` is not a ``call_module`` node.
        """
        assert (
            root_node.op == "call_module"
        ), "Expecting module node to be a call_module Node"
        # Use one normalized key for both the module lookup and the parent/attr
        # split below, so the two cannot disagree.
        root_target = str(root_node.target)
        root_module = named_modules[root_target]

        def get_modules(pattern: Any) -> Any:
            """Given a node pattern, extract the corresponding modules.

            e.g. input: (relu_node, (bn_node, conv_node))
                 output: (relu_module, (bn_module, conv_module))

            Tuples/lists are mapped element-wise into tuples. Leaf nodes map to:
            ``call_module`` -> the module in ``named_modules``; functional relu
            -> a new ``nn.ReLU``; other ``call_function``/``call_method`` -> the
            target itself; anything else -> ``MatchAllNode``.
            """
            if isinstance(pattern, (tuple, list)):
                return tuple(get_modules(sub_pattern) for sub_pattern in pattern)
            n = pattern
            if n.op == "call_module":
                return named_modules[n.target]
            if n.op == "call_function" and n.target == torch.nn.functional.relu:
                # Functional relu has no module of its own, so a fresh ReLU is
                # built for every match. It mirrors the root module's mode.
                relu = torch.nn.ReLU()
                relu.training = root_module.training
                return relu
            if n.op in ("call_function", "call_method"):
                return n.target
            return MatchAllNode

        # since relu can be used multiple times, a relu module is created per match
        matched_modules = get_modules(matched_node_pattern)

        def get_matched_types(m: Any) -> Any:
            """Mirror the structure of ``m``, replacing modules by their type.

            The type is taken before any parametrizations are applied.
            """
            if isinstance(m, tuple):
                return tuple(get_matched_types(sub) for sub in m)
            if isinstance(m, torch.nn.Module):
                return type_before_parametrizations(m)
            return m

        matched_module_types = get_matched_types(matched_modules)
        module_parent_name, module_name = _parent_name(root_target)
        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
        # TODO: change the signature for fuser_method to take matched module patterns
        # as input
        fused_module = fuser_method(is_qat, *matched_modules)
        # Install the fused module where the root module used to live.
        setattr(named_modules[module_parent_name], module_name, fused_module)
        extra_args = [load_arg(extra_input) for extra_input in extra_inputs]
        node = fused_graph.node_copy(root_node, load_arg)
        node.args = (*node.args, *extra_args)
        return node
122
123
124def _get_fusion_pattern_to_fuse_handler_cls(
125    backend_config: BackendConfig,
126) -> Dict[Pattern, Callable]:
127    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
128    for pattern, config in backend_config._pattern_complex_format_to_config.items():
129        if config.fuser_method is not None:
130            # TODO: is this logic right?
131            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
132    return fusion_pattern_to_fuse_handlers
133