# mypy: allow-untyped-defs
import warnings
from typing import Any, Callable, Dict, List, Tuple, Union

from torch.ao.quantization.backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from torch.ao.quantization.backend_config.utils import (
    get_fuser_method_mapping,
    get_fusion_pattern_to_extra_inputs_getter,
    get_fusion_pattern_to_root_node_getter,
)
from torch.ao.quantization.utils import NodePattern, Pattern
from torch.fx import GraphModule, map_arg, Node
from torch.fx.graph import Graph

from .custom_config import FuseCustomConfig
from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler
from .match_utils import _is_match, MatchAllNode
from .pattern_utils import _sorted_patterns_dict


__all__ = [
    "fuse",
    # TODO: We should make this private in the future
    # This is currently needed for test_public_bindings for some reason
    "FuseHandler",
]


def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    """Fuse matched node patterns of ``model`` into a freshly built graph.

    Args:
        model: the ``GraphModule`` whose graph is scanned for fusion patterns.
        is_qat: forwarded unchanged to every fuse handler's ``fuse`` call.
        fuse_custom_config: a ``FuseCustomConfig``; ``None`` is replaced by a
            default-constructed ``FuseCustomConfig()``. A ``dict`` is still
            accepted but emits a ``FutureWarning`` and is converted with
            ``FuseCustomConfig.from_dict``.
        backend_config: a ``BackendConfig``; ``None`` is replaced by
            ``get_native_backend_config()``. A ``dict`` is still accepted but
            emits a ``FutureWarning`` and is converted with
            ``BackendConfig.from_dict``.

    Returns:
        A new ``GraphModule`` constructed from ``model`` and the fused graph.
    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, dict):
        # The message names ``fuse`` (this function) so that it agrees with
        # the fuse_custom_config warning above.
        warnings.warn(
            "Passing a backend_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    named_modules = dict(model.named_modules())

    if backend_config is None:
        backend_config = get_native_backend_config()

    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(
        _get_fusion_pattern_to_fuse_handler_cls(backend_config)
    )
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(
        backend_config
    )
    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(
        backend_config
    )

    # find fusion
    fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls)
    # TODO: change this to inplace changes to graph, since we no longer construct
    # new GraphModule anymore
    fused_graph = Graph()
    # Maps the name of a node in the original graph to the value that stands
    # in for it in ``fused_graph``.
    env: Dict[Any, Any] = {}

    def load_arg(a):
        """Replace every node inside ``a`` by its ``env`` entry (looked up by name)."""
        return map_arg(a, lambda node: env[node.name])

    def default_root_node_getter(node_pattern):
        """Follow the last element of nested patterns until it is a ``Node``."""
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]

    for node in model.graph.nodes:
        # Nodes without a match unpack to five ``None`` values.
        (
            maybe_last_node,
            pattern,
            matched_node_pattern,
            obj,
            node_to_subpattern,
        ) = fusion_pairs.get(node.name, (None, None, None, None, None))
        # get the corresponding subpattern for the current node
        if node_to_subpattern is not None:
            node_subpattern = node_to_subpattern.get(node, None)
        else:
            node_subpattern = None
        if maybe_last_node is node:
            # This node is the one the match was started from, so the fused
            # replacement is emitted here.
            assert obj is not None
            root_node_getter = fusion_pattern_to_root_node_getter.get(
                pattern, default_root_node_getter
            )
            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(
                pattern, None
            )
            extra_inputs = []
            if extra_inputs_getter is not None:
                extra_inputs = extra_inputs_getter(matched_node_pattern)
            # TODO: add validation that root_node is a module and has the same type
            # as the root_module in the configuration
            env[node.name] = obj.fuse(
                load_arg,
                named_modules,
                fused_graph,
                root_node,
                extra_inputs,
                matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config,
                fuser_method_mapping,
                is_qat,
            )
        elif maybe_last_node is None or node_subpattern is MatchAllNode:
            # Unmatched nodes are copied over unchanged.
            env[node.name] = fused_graph.node_copy(node, load_arg)
        # node matched in patterns and is not root is removed here
        # (it is neither fused nor copied into ``fused_graph``)

    model = GraphModule(model, fused_graph)
    return model


def _find_matches(
    root: GraphModule,
    graph: Graph,
    pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
    """Match the nodes of ``graph`` against the given fusion patterns.

    The graph is walked in reverse order. For every node that is not yet part
    of a match, the patterns are tried in the iteration order of
    ``pattern_to_fuse_handler_cls`` and the first one accepted by ``_is_match``
    is recorded.

    Args:
        root: module whose ``named_modules()`` are passed to ``_is_match``.
        graph: the graph whose nodes are matched.
        pattern_to_fuse_handler_cls: maps a pattern to a callable that is
            invoked with the matched node to build the handler for the match.

    Returns:
        A dict from node name to the tuple ``(node the match started from,
        pattern, matched node pattern, handler, node -> subpattern map)``.
        The node -> subpattern map is a single dict shared by all entries.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, match_value)
    match_map: Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]
    ] = {}
    # a map from node to the matched subpattern
    node_to_subpattern: Dict[Node, Any] = {}

    # TODO: dedup with quantization matching function in match_utils.py
    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
        """Record ``node`` (and, for tuple patterns, its args) as part of ``match``.

        ``matched_node_pattern`` is extended in place so that it mirrors the
        nesting of ``pattern``.
        """
        if isinstance(pattern, tuple):
            s, *args = pattern
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
            # zip stops at the shorter of the sub-patterns and the node's args
            for subpattern, arg in zip(args, node.args):
                apply_match(
                    subpattern, arg, match, current_node_pattern, node_to_subpattern
                )
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern matches will take precedence
            if node.name not in match_map:
                matched_node_pattern.append(node)
                # MatchAllNode here is actually MatchAllInputNode which should not
                # be added to match_map
                if pattern is not MatchAllNode:
                    node_to_subpattern[node] = pattern
                    # From here on ``pattern`` is the full pattern of the match,
                    # not the subpattern this node was matched against.
                    root_node, pattern, handler = match
                    match_map[node.name] = (
                        root_node,
                        pattern,
                        matched_node_pattern,
                        handler,
                        node_to_subpattern,
                    )

    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
                matched_node_pattern: List[Node] = []
                if _is_match(modules, node, pattern):
                    apply_match(
                        pattern,
                        node,
                        (node, pattern, fuse_handler_cls(node)),
                        matched_node_pattern,
                        node_to_subpattern,
                    )
                    # only the first matching pattern is applied to a node
                    break

    return match_map