# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union

import torch
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new
from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern
from torch.fx.graph import Graph, Node
from torch.nn.utils.parametrize import type_before_parametrizations

from .custom_config import FuseCustomConfig
from .match_utils import MatchAllNode


__all__ = [
    "DefaultFuseHandler",
    "FuseHandler",
]


# ----------------------------
# Fusion Pattern Registrations
# ----------------------------


# Base Pattern Handler
class FuseHandler(ABC):
    """Base handler class for the fusion patterns.

    Subclasses implement ``fuse`` to replace a matched node pattern with a
    single node emitted into ``fused_graph``.
    """

    @abstractmethod
    def __init__(self, node: Node):
        pass

    @abstractmethod
    def fuse(
        self,
        load_arg: Callable,
        named_modules: Dict[str, torch.nn.Module],
        fused_graph: Graph,
        root_node: Node,
        extra_inputs: List[Any],
        matched_node_pattern: NodePattern,
        fuse_custom_config: FuseCustomConfig,
        fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
        is_qat: bool,
    ) -> Node:
        """Fuse ``matched_node_pattern`` and return the resulting node in ``fused_graph``."""
        pass


class DefaultFuseHandler(FuseHandler):
    """Fuse handler that swaps the root module for a fused module.

    The fused module is built by the fuser method that ``fuser_method_mapping``
    associates with the matched module types. It replaces the module named by
    ``root_node.target``, and ``root_node`` is then copied into the fused graph.
    """

    def __init__(self, node: Node):
        super().__init__(node)  # type:ignore[safe-super]

    def fuse(
        self,
        load_arg: Callable,
        named_modules: Dict[str, torch.nn.Module],
        fused_graph: Graph,
        root_node: Node,
        extra_inputs: List[Any],
        matched_node_pattern: NodePattern,
        fuse_custom_config: FuseCustomConfig,
        fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
        is_qat: bool,
    ) -> Node:
        """Fuse the modules of ``matched_node_pattern`` into the root module.

        Args:
            load_arg: callable that maps original-graph arguments into
                ``fused_graph``; passed to ``Graph.node_copy`` and applied to
                each entry of ``extra_inputs``.
            named_modules: qualified name -> module. The parent of the root
                module is mutated so that it holds the fused module.
            fused_graph: graph that receives the copied root node.
            root_node: ``call_module`` node whose module is replaced.
            extra_inputs: extra inputs appended, after ``load_arg``, to the
                copied node's args.
            matched_node_pattern: nested tuple/list of matched nodes, e.g.
                ``(relu_node, (bn_node, conv_node))``.
            fuse_custom_config: not used by this handler.
            fuser_method_mapping: pattern -> fuser method, searched with the
                matched module types.
            is_qat: forwarded as the first argument of the fuser method.

        Returns:
            The copy of ``root_node`` in ``fused_graph``, with the loaded
            ``extra_inputs`` appended to its args.
        """
        assert (
            root_node.op == "call_module"
        ), "Expecting module node to be a call_module Node"
        root_module = named_modules[str(root_node.target)]

        def get_modules(pattern):
            """Given a node pattern, extract the corresponding modules.

            e.g. input:  (relu_node, (bn_node, conv_node))
                 output: (relu_module, (bn_module, conv_module))

            Elements of the result may be modules, nested tuples, callables
            (targets of call_function/call_method nodes) or ``MatchAllNode``.
            """
            if isinstance(pattern, (tuple, list)):
                return tuple(get_modules(sub_pattern) for sub_pattern in pattern)
            n = pattern
            if n.op == "call_module":
                return named_modules[n.target]
            if n.op == "call_function" and n.target == torch.nn.functional.relu:
                # Functional relu can be used multiple times, so a fresh ReLU
                # module is created for each match, mirroring the root
                # module's training mode.
                relu = torch.nn.ReLU()
                relu.training = root_module.training
                return relu
            if n.op in ("call_function", "call_method"):
                return n.target
            return MatchAllNode

        matched_modules = get_modules(matched_node_pattern)

        def get_matched_types(m):
            """Replace every module in ``m`` by its pre-parametrization type."""
            if isinstance(m, tuple):
                return tuple(map(get_matched_types, m))
            if isinstance(m, torch.nn.Module):
                return type_before_parametrizations(m)
            return m

        matched_module_types = get_matched_types(matched_modules)
        module_parent_name, module_name = _parent_name(root_node.target)
        fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
        # TODO: change the signature for fuser_method to take matched module patterns
        # as input
        fused_module = fuser_method(is_qat, *matched_modules)
        # Replace the root module on its parent with the fused module.
        setattr(named_modules[module_parent_name], module_name, fused_module)
        # Load the extra inputs before copying the root node (same order as
        # before, in case load_arg is order-sensitive).
        extra_args = [load_arg(extra_input) for extra_input in extra_inputs]
        node = fused_graph.node_copy(root_node, load_arg)
        node.args = (*node.args, *extra_args)
        return node


def _get_fusion_pattern_to_fuse_handler_cls(
    backend_config: BackendConfig,
) -> Dict[Pattern, Callable]:
    """Map each backend pattern that defines a ``fuser_method`` to ``DefaultFuseHandler``."""
    # TODO: is this logic right?
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {
        pattern: DefaultFuseHandler
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
        if config.fuser_method is not None
    }
    return fusion_pattern_to_fuse_handlers