from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSNodeTargetType


toq = torch.ops.quantized


def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for s in base_name_to_sets_of_related_ops.values():
        s_list = list(s)
        # add every bidirectional pair
        for idx_0 in range(0, len(s_list)):
            for idx_1 in range(idx_0, len(s_list)):
                type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
                type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))

    return type_a_related_to_b


NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    Tuple[
        str, Any
    ],  # call_method name and first argument, example: ("to", torch.float16)
]
NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]


def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Returns a list of potential fusions, in reverse order. The order is
    reversed to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the index of the op we should use to match other
    related ops. Note: base_op_idx is specified in non-reverse order, i.e. a
    base_op_idx of 0 represents the first op in regular (non-reverse) order,
    1 represents the second op, etc.
    """
    results: List[Tuple[NSFusionType, int]] = []

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())

    default_base_op_idx = 0
    for quant_pattern in all_quant_patterns.keys():
        # TODO: this is a temporary hack to flatten the patterns from quantization
        # so that they work with the ns matcher function; maybe we should use
        # `_is_match` in torch.ao.quantization.fx.match_utils to match the patterns
        if (
            isinstance(quant_pattern, tuple)
            and len(quant_pattern) == 2
            and isinstance(quant_pattern[1], tuple)
            and len(quant_pattern[1]) == 2
        ):
            # flatten a pattern of the form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])

        # Only patterns of multiple ops are fusions; ignore
        # patterns which contain a single op (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern, add additional patterns with observers and
        # fake quants at the end.
        # TODO(future PR): if needed, implement matching for a node
        # having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings. For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        (
            (("to", torch.float16), F.relu, F.linear, "dequantize"),
            fp16_em_base_op_idx,
        ),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]

    return results


def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns True if the chain of nodes ending with `end_node` matches
    the reversed fusion pattern.
143 """ 144 cur_node = end_node 145 for fusion_idx in range(len(reversed_fusion)): 146 # each node can only belong to one matched pattern 147 if cur_node in seen_nodes: 148 return False 149 150 cur_fusion_el = reversed_fusion[fusion_idx] 151 152 if cur_node.op == "call_function": 153 fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and ( 154 not isinstance(cur_fusion_el, type) 155 ) 156 if fusion_el_is_fun: 157 if cur_node.target != cur_fusion_el: 158 return False 159 if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): 160 cur_node = cur_node.args[0] 161 else: 162 return False 163 else: 164 return False 165 166 elif cur_node.op == "call_module": 167 fusion_el_is_mod = isinstance(cur_fusion_el, type) 168 if fusion_el_is_mod: 169 assert isinstance(cur_node.target, str) 170 target_mod = getattr_from_fqn(gm, cur_node.target) 171 if not isinstance(cur_fusion_el, type): 172 return False 173 if not isinstance(target_mod, cur_fusion_el): 174 return False 175 if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): 176 cur_node = cur_node.args[0] 177 else: 178 return False 179 else: 180 return False 181 182 elif cur_node.op == "call_method": 183 fusion_el_is_meth_with_second_arg = ( 184 isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2 185 ) 186 fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str) 187 if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg: 188 if fusion_el_is_meth_without_args: 189 if cur_node.target != cur_fusion_el: 190 return False 191 else: 192 assert isinstance(cur_fusion_el, tuple) 193 if cur_node.target != cur_fusion_el[0]: 194 return False 195 elif len(cur_node.args) < 2: 196 return False 197 elif cur_node.args[1] != cur_fusion_el[1]: 198 return False 199 200 if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): 201 cur_node = cur_node.args[0] 202 else: 203 return False 204 else: 205 return False 206 else: 207 return False 208 209 return True 210