# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from itertools import chain
from typing import cast, List, Optional, Tuple

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
    XNNPartitionerConfig,
)
from executorch.backends.xnnpack.utils.quant_utils import (
    extract_qdq_affine_op_args_for_decomposed_ops,
    is_affine_qdq,
    is_dequant,
    is_dynamic_qdq,
    is_per_channel,
    is_per_channel_group,
    is_qparam,
    is_quant,
)
from executorch.backends.xnnpack.utils.utils import (
    get_input_node,
    is_getitem,
    is_node,
    is_param_node,
)
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram
from torch.fx.passes.utils.source_matcher_utils import (
    get_source_partitions,
    SourcePartition,
)

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class GEMMConfig(XNNPartitionerConfig):
    """
    GEMM-like ops such as Convolution, Addmm, and Linear mostly behave in the
    same way: each has a weight, bias, and activation node. The only difference
    between these ops is that the weight, bias, and activation live at different
    indices of the node's arguments. This class generalizes the logic needed to
    partition these different ops.
    """

    def __init__(self, weight_idx, bias_idx, act_idx, fused_acts, **kwargs):
        super().__init__(**kwargs)
        self.weight_idx = weight_idx
        self.bias_idx = bias_idx
        self.act_idx = act_idx
        self.fused_acts = fused_acts

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            # short circuit if we don't pass common constraints
            return False

        is_valid, _ = self.get_deps(node, ep)
        if not is_valid:
            why(node, "Failed to get valid dependent nodes.")
        return is_valid

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        partition = [node]
        _, deps = self.get_deps(node, ep)
        partition.extend(deps)

        return partition

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return None

    def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
        weight = get_input_node(node, self.weight_idx)

        if not is_dequant(weight):
            return ConfigPrecisionType.FP32

        activation = get_input_node(node, self.act_idx)
        if is_dynamic_qdq(activation):
            return ConfigPrecisionType.DYNAMIC_QUANT

        return ConfigPrecisionType.STATIC_QUANT
    def get_deps(
        self,
        node: torch.fx.Node,
        ep: ExportedProgram,
    ) -> Tuple[bool, List[torch.fx.Node]]:
        """
        Gets all dependencies for this gemm partition. Returns a tuple of
        a bool indicating if the deps are valid and a list of all the
        dep nodes.
        """
        precision = self._detect_precision(node)
        if precision not in self.supported_precision_types():
            # detected precision but it is either disabled or not supported
            return (False, [])

        valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
        valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
        valid_act, act_deps = self._get_act_deps(node, ep, precision)
        valid_output, output_deps = self._get_output_deps(node, ep, precision)

        valid_deps = valid_bias and valid_weight and valid_act and valid_output
        deps = list(chain(bias_deps, weight_deps, act_deps, output_deps))

        return valid_deps, deps

    def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if precision == ConfigPrecisionType.FP32:
            # First find the weight
            weight_node = get_input_node(node, self.weight_idx)
            if not is_param_node(ep, weight_node):
                return (False, [])  # weight must be a static param
            gemm_deps.append(weight_node)

            return (True, gemm_deps)
        else:
            # Quantized weight deps
            dequant_node = get_input_node(node, self.weight_idx)
            if not is_dequant(dequant_node):
                return False, []
            gemm_deps.append(dequant_node)
            weight = get_input_node(dequant_node, 0)
            if not is_param_node(ep, weight):
                return False, []
            gemm_deps.append(weight)

            if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                if len(dequant_node.all_input_nodes) < 2:
                    # Expected per-channel quantized weights to have scale/zp nodes
                    return False, []

                gemm_deps.extend(dequant_node.all_input_nodes[1:3])
            return (True, gemm_deps)

    def _get_output_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if precision == ConfigPrecisionType.STATIC_QUANT:
            # Look for fused activations and tail end quant node
            node_users = list(node.users.keys())
            if len(node_users) != 1:
                # Expect quantized node to have a single output (fused act or quant)
                return False, []

            # Check if the quantized pattern has a fused activation
            n_output = node_users[0]
            if (
                n_output.op == "call_function"
                and format_target_name(n_output.target.__name__) in self.fused_acts
            ):
                gemm_deps.append(n_output)
                fused_out_users = list(n_output.users.keys())
                if len(fused_out_users) == 1:
                    n_output = fused_out_users[0]

            if not is_quant(n_output):
                # Expected gemm_node --> fused_act (optional) --> quant
                return (False, [])
            gemm_deps.append(n_output)
        elif precision == ConfigPrecisionType.FP32:
            # Look for fused activations only, and partition with fp32 op
            node_users = list(node.users.keys())
            if len(node_users) == 1:
                n_output = node_users[0]
                if (
                    n_output.op == "call_function"
                    and format_target_name(n_output.target.__name__) in self.fused_acts
                ):
                    gemm_deps.append(n_output)

        # FP32 and Dynamic Quant have no output dependencies
        return (True, gemm_deps)

    def _get_bias_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if len(node.all_input_nodes) > 2 and self.bias_idx:
            bias_node = get_input_node(node, self.bias_idx)
            if bias_node:
                if not is_param_node(ep, bias_node):
                    return (False, [])  # bias node must be a static param
                gemm_deps.append(bias_node)

        return (True, gemm_deps)

    def _get_act_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if precision == ConfigPrecisionType.FP32:
            return (True, [])
        else:
            dq_input = get_input_node(node, self.act_idx)
            if not is_dequant(dq_input):
                # Expected quantized input to be a dequant node
                return False, []
            gemm_deps.append(dq_input)
            if precision == ConfigPrecisionType.STATIC_QUANT:
                # if static quant we are done after finding first dq_input
                return (True, gemm_deps)

            # q input node
            q_input = get_input_node(dq_input, 0)
            if not is_quant(q_input):
                return (False, [])

            gemm_deps.append(q_input)
            q_input_args = q_input.args
            if is_affine_qdq(q_input):
                q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
            if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
                # expected to find getitem nodes from choose_qparams
                return (False, [])

            getitem1 = q_input_args[1]
            getitem2 = q_input_args[2]

            if not (is_getitem(getitem1) and is_getitem(getitem2)):
                # expected getitem nodes from choose_qparams
                return (False, [])

            gemm_deps.extend([getitem1, getitem2])
            choose_qparam = get_input_node(getitem1, 0)
            if not is_qparam(choose_qparam):
                # expected to find choose_qparam node
                return (False, [])
            gemm_deps.append(choose_qparam)
            return (True, gemm_deps)
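

# The dependency walks above expect the following (abridged) graph shapes around
# a GEMM node, mirroring the checks in _get_weight_deps, _get_act_deps, and
# _get_output_deps:
#
#   FP32:           act --> gemm --> [fused act]
#   Static quant:   dq(act) --> gemm --> [fused act] --> q
#   Dynamic quant:  choose_qparams --> q(act) --> dq(act) --> gemm
#
# In every case the weight feeds in either as a static param (FP32) or as a
# dequantize of a static param (quantized paths), optionally with per-channel
# scale/zero-point inputs.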


class LinearConfig(GEMMConfig):
    target_name = "linear.default"

    def __init__(self, **kwargs):
        super().__init__(
            weight_idx=1,
            bias_idx=2,
            act_idx=0,
            fused_acts=["relu.default", "hardtanh.default"],
            **kwargs,
        )

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.linear.default

    def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
            # if force_fp32_dynamic_linear is enabled and we detected this as fp32,
            # then we do not partition the weight node
            return (True, [])

        return super()._get_weight_deps(node, ep, precision)

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.DYNAMIC_QUANT,
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
        ]


class ConvolutionConfig(GEMMConfig):
    target_name = "convolution.default"

    def __init__(self, **kwargs):
        super().__init__(
            weight_idx=1,
            bias_idx=2,
            act_idx=0,
            fused_acts=["relu.default", "hardtanh.default"],
            **kwargs,
        )

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Currently we have no support for 3D convolution or transposed convolution.
        """
        if not super().check_constraints(node, ep):
            return False

        conv_stride = cast(List[int], node.args[3])
        if len(conv_stride) > 2:
            why(node, "Only support 1D + 2D Conv")
            return False  # Only support 1D + 2D Conv

        transposed = cast(bool, node.args[6])
        if transposed:
            why(node, "Transposed Conv is not supported")
            return False  # Currently don't support transposed conv

        return True

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
        ]
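

# aten.addmm(bias, mat1, mat2) takes the bias as arg 0, the activation (mat1)
# as arg 1, and the weight (mat2) as arg 2, hence the index mapping below
# differs from LinearConfig / ConvolutionConfig.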


class AddmmConfig(GEMMConfig):
    """
    Handles the legacy form of addmm partitioning, which includes partitioning
    using source partitions.
    """

    target_name = "addmm.default"

    def __init__(self, **kwargs):
        super().__init__(
            weight_idx=2,
            bias_idx=0,
            act_idx=1,
            fused_acts=["relu.default", "hardtanh.default"],
            **kwargs,
        )
        self.src_partitions = None
        self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

    def get_deps(
        self,
        node: torch.fx.Node,
        ep: ExportedProgram,
    ) -> Tuple[bool, List[torch.fx.Node]]:
        """
        Gets all dependencies for this gemm partition. Returns a tuple of
        a bool indicating if the deps are valid and a list of all the
        dep nodes. This also handles the case where the addmm node belongs
        to a linear source partition.
        """
        if self.src_partitions is None:
            # Cache src partitions so we don't have to recompute them every time
            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)

        # src_partition is None if node is not in source partition,
        # otherwise gives us the linear source partition it belongs to
        src_partition = None
        for partition_list in self.src_partitions.values():
            for partition in partition_list:
                if node in partition.nodes:
                    src_partition = partition

        if src_partition:
            # if addmm belongs to a linear src partition, then partition the
            # src partition and get its deps
            return self.get_deps_from_src_partition(node, ep, src_partition)

        return super().get_deps(node, ep)

    def get_deps_from_src_partition(
        self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
    ):
        """
        Gets all the dependencies for the src partition. This is done by simulating
        the linear node from the src partition. We find the associated weights,
        act, and bias from the linear src partition, and plug those in as the addmm
        node's args. We also take the users of the src partition's output node as
        the addmm node's users. Finally, we run GEMMConfig's get_deps method on this
        faked linear node. After getting the deps, we restore the addmm node's
        original args and users.
        """

        def find_partition_args(input_node):
            while (
                len(input_node.all_input_nodes) != 0
                and input_node not in src_partition.input_nodes
            ):
                input_node = input_node.all_input_nodes[0]
            return input_node

        old_args, old_users = node.args, node.users

        fake_args = []
        for arg in node.args:
            # map addmm's args to the source partition's inputs
            # basically simulating what the args of the linear node would be
            fake_args.append(find_partition_args(arg))

        # validate source partition
        if (
            # bias must be in source partition
            (self.bias_idx and fake_args[self.bias_idx] not in src_partition.nodes)
            # activation input must be an input node to partition
            or fake_args[self.act_idx] not in src_partition.input_nodes
            # weight can either be in the nodes or input_nodes
            or fake_args[self.weight_idx]
            not in (src_partition.nodes + src_partition.input_nodes)
            # there can only be a single output node in partition
            or len(src_partition.output_nodes) != 1
        ):
            return (False, [])

        # map addmm's args to the source partition linear's inputs and users
        node.args = tuple(fake_args)
        node.users = src_partition.output_nodes[0].users
        valid_deps, deps = super().get_deps(node, ep)

        # Reset addmm node back to old args and users
        node.args = old_args
        node.users = old_users

        return valid_deps, list(set(deps) | set(src_partition.nodes))

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ]
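

# MMConfig reuses AddmmConfig's source-partition handling for the bias-free
# case: aten.mm(act, weight) has no bias argument, which is typically what
# nn.Linear without bias decomposes to.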


class MMConfig(AddmmConfig):
    target_name = "mm.default"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bias_idx = None
        self.weight_idx = 1
        self.act_idx = 0

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ]
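

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only). These configs are typically not constructed
# directly; they are registered with the XNNPACK partitioner, which calls
# check_constraints / get_node_and_deps while lowering an exported program.
# The high-level calls below are assumptions for illustration, not part of
# this file:
#
#   from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
#       XnnpackPartitioner,
#   )
#   from executorch.exir import to_edge_transform_and_lower
#
#   ep = torch.export.export(model, example_inputs)  # hypothetical model/inputs
#   edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
# ------------------------------------------------------------------------------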