# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, Dict, List

import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)

from torch.export import export_for_training
from torch.fx import Graph, Node
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
    SubgraphMatcherWithNameNodeMap,
)

from .qconfig import QuantizationConfig


OP_TO_ANNOTATOR: Dict[OpOverload, Callable] = {}


def annotate(graph: Graph, quant_config: QuantizationConfig) -> None:
    # Pattern annotation
    _annotate_rmsnorm_pattern(graph, quant_config)
    _annotate_fused_activation_pattern(graph, quant_config)

    # Per-op annotation
    for node in graph.nodes:
        if node.op == "placeholder":
            annotate_placeholder(node, quant_config)
        elif node.op == "call_function":
            annotate_func = OP_TO_ANNOTATOR.get(node.target, None)
            if annotate_func is not None:
                annotate_func(node, quant_config)


def register_annotator(ops: List[OpOverload]):

    def decorator(annotator_fn: Callable) -> Callable:
        for op in ops:
            OP_TO_ANNOTATOR[op] = annotator_fn
        return annotator_fn

    return decorator


def _is_annotated(node: Node) -> bool:
    """Return True if the given node has already been annotated."""
    KEY = "quantization_annotation"
    return KEY in node.meta and node.meta[KEY]._annotated


def _mark_as_annotated(nodes: List[Node]) -> None:
    KEY = "quantization_annotation"
    for node in nodes:
        if KEY not in node.meta:
            node.meta[KEY] = QuantizationAnnotation()
        node.meta[KEY]._annotated = True


def _is_float_activation_tensor(node: Node) -> bool:
    if not isinstance(node, Node):
        return False
    if "val" not in node.meta:
        return False
    if not isinstance(node.meta["val"], FakeTensor):
        return False
    return node.meta["val"].dtype == torch.float32


def _annotate_fused_activation_pattern(
    graph: Graph, quant_config: QuantizationConfig
) -> None:
    for relu_node in graph.nodes:
        # Look for a relu/relu6 node
        if relu_node.op != "call_function":
            continue
        if relu_node.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
            torch.ops.aten.relu6.default,
        ]:
            continue

        # The producer must feed only this activation for the pair to fuse
        producer_node = relu_node.args[0]
        if not isinstance(producer_node, Node):
            continue
        if producer_node.op != "call_function":
            continue
        if len(producer_node.users) != 1:
            continue

        # Handle affine + relu fusion
        if producer_node.target in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
            torch.ops.aten.linear.default,
        ]:
            weight_node = producer_node.args[1]
            _annotate_input_qspec_map(
                producer_node,
                weight_node,
                quant_config.weight,
            )
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, weight_node, relu_node])
            continue

        # Handle arithmetic + relu fusion
        if producer_node.target in [
            torch.ops.aten.add.Scalar,
            torch.ops.aten.add.Tensor,
            torch.ops.aten.add_.Scalar,
            torch.ops.aten.add_.Tensor,
            torch.ops.aten.div.Scalar,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.div_.Scalar,
            torch.ops.aten.div_.Tensor,
            torch.ops.aten.divide.Scalar,
            torch.ops.aten.divide.Tensor,
            torch.ops.aten.mul.Scalar,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.mul_.Scalar,
            torch.ops.aten.mul_.Tensor,
            torch.ops.aten.rsub.Scalar,
            torch.ops.aten.rsub.Tensor,
            torch.ops.aten.sub.Scalar,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.sub_.Scalar,
            torch.ops.aten.sub_.Tensor,
        ]:
            _annotate_output_qspec(relu_node, quant_config.activation)
            _mark_as_annotated([producer_node, relu_node])
            continue
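# Illustrative note (not part of the pass; the module name below is made up):
# for a graph exported from
#
#     class ConvReLU(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.conv = torch.nn.Conv2d(3, 8, 3)
#
#         def forward(self, x):
#             return torch.relu(self.conv(x))
#
# the pass above attaches the weight qspec to the conv's weight input and the
# activation qspec to the relu's output, while the conv's own output gets no
# qspec, so the conv + relu pair is observed as a single fused unit.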
def _annotate_rmsnorm_pattern(graph: Graph, quant_config: QuantizationConfig) -> None:

    # Each pattern returns an extra empty dict, the name -> node map expected
    # by SubgraphMatcherWithNameNodeMap.
    class ExecuTorchPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    class MTKPattern(torch.nn.Module):
        def forward(self, x):
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
            return norm, {}

    for pattern_cls in (ExecuTorchPattern, MTKPattern):
        pattern_gm = export_for_training(pattern_cls(), (torch.randn(3, 3),)).module()
        # ignore_literals=True lets the matcher accept a different epsilon
        matcher = SubgraphMatcherWithNameNodeMap(
            pattern_gm, ignore_literals=True, remove_overlapping_matches=False
        )
        matches = matcher.match(graph)
        for match in matches:
            target_nodes = []
            for node in match.nodes_map.values():
                if node in match.placeholder_nodes:
                    continue
                if node.op == "call_function" and node.target in OP_TO_ANNOTATOR:
                    target_nodes.append(node)

            # Skip matches that overlap an already-annotated region
            if any(_is_annotated(node) for node in target_nodes):
                continue
            _mark_as_annotated(target_nodes)
            for node in match.returning_nodes:
                _annotate_output_qspec(node, quant_config.activation)


def annotate_placeholder(node: Node, quant_config: QuantizationConfig) -> None:
    if _is_annotated(node):
        return

    if _is_float_activation_tensor(node):
        _annotate_output_qspec(node, quant_config.activation)

    _mark_as_annotated([node])


@register_annotator(
    [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.linear.default,
    ]
)
def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
    if _is_annotated(node):
        return

    weight_node = node.args[1]
    _annotate_input_qspec_map(
        node,
        weight_node,
        quant_config.weight,
    )
    _annotate_output_qspec(node, quant_config.activation)

    # Mark the weight as annotated as well, since it is a constant node
    _mark_as_annotated([node, weight_node])


@register_annotator(
    [
        torch.ops.aten.add.Scalar,
        torch.ops.aten.add.Tensor,
        torch.ops.aten.add_.Scalar,
        torch.ops.aten.add_.Tensor,
        torch.ops.aten.bmm.default,
        torch.ops.aten.div.Scalar,
        torch.ops.aten.div.Tensor,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.div_.Tensor,
        torch.ops.aten.divide.Scalar,
        torch.ops.aten.divide.Tensor,
        torch.ops.aten.gelu.default,
        torch.ops.aten.group_norm.default,
        torch.ops.aten.layer_norm.default,
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.matmul.default,
        torch.ops.aten.mul.Scalar,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.mul_.Scalar,
        torch.ops.aten.mul_.Tensor,
        torch.ops.aten.pow.Scalar,
        torch.ops.aten.pow.Tensor_Scalar,
        torch.ops.aten.pow.Tensor_Tensor,
        torch.ops.aten.prelu.default,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.rsub.Tensor,
        torch.ops.aten.silu.default,
        torch.ops.aten.sub.Scalar,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Scalar,
        torch.ops.aten.sub_.Tensor,
    ]
)
def annotate_output_qspec(node: Node, quant_config: QuantizationConfig) -> None:
    if _is_annotated(node):
        return
    _annotate_output_qspec(node, quant_config.activation)
    _mark_as_annotated([node])
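# Extension sketch (hypothetical; not registered upstream): further
# single-output ops can be covered by registering an annotator the same way,
# e.g.:
#
#     @register_annotator([torch.ops.aten.tanh.default])
#     def annotate_tanh(node: Node, quant_config: QuantizationConfig) -> None:
#         if _is_annotated(node):
#             return
#         _annotate_output_qspec(node, quant_config.activation)
#         _mark_as_annotated([node])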
@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
    if _is_annotated(node):
        return

    # args[0] of aten.embedding is the lookup table (weight)
    wgt_node = node.args[0]
    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
    _mark_as_annotated([node])
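# Usage sketch (illustrative only): in the PT2E flow, annotate() is normally
# invoked by the backend Quantizer during prepare_pt2e rather than called
# directly; `quant_config` below stands for a QuantizationConfig built from
# this package's qconfig module, whose construction is not shown in this file.
#
#     from torch.export import export_for_training
#
#     model = ...  # an eager torch.nn.Module
#     exported_gm = export_for_training(model, example_inputs).module()
#     annotate(exported_gm.graph, quant_config)
#     # exported_gm's nodes now carry "quantization_annotation" metadata.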