# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import abstractmethod
from enum import Enum
from typing import List, Optional

import torch
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
    PartitionerConfig,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class ConfigPrecisionType(Enum):
    FP32 = 1
    STATIC_QUANT = 2
    DYNAMIC_QUANT = 3


class XNNPartitionerConfig(PartitionerConfig):
    """
    Base class for XNNPACK partitioner configs. Wrapping every XNNPACK
    partitioner config in this base class lets us apply shared controls across
    all PartitionerConfigs. It also defines a property for the supported
    precision types, so each config can declare which precision types it
    supports and users can toggle which of those types to enable.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.enabled_precision_types = self.supported_precision_types()
        # Flag used in GEMMConfig()
        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)

    def get_partition(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        """
        Overriding abstract method get_partition.

        Returns the partitioned nodes from get_node_and_deps, but also labels
        them with the name of the XNNPartitionerConfig class that returned this
        set of nodes. This enforces that all partitions returned by
        XNNPartitionerConfigs are labeled with the partitioner config that
        returned them.
        """
        partitioned_nodes = self.get_node_and_deps(node, ep)
        # Label partitioned nodes with the name of the partitioner config
        for node in partitioned_nodes:
            if "xnn_partitioner_config" in node.meta:
                node.meta["xnn_partitioner_config"].append(self.__class__.__name__)
            else:
                node.meta["xnn_partitioner_config"] = [self.__class__.__name__]

        return partitioned_nodes

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        # By default, if not specified, we do not halt decomposition for this config
        return None

    @abstractmethod
    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        """
        Returns the precision types supported by this partitioner config
        """
        pass

    @abstractmethod
    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        """
        Takes in a node and its exported program and returns a list containing
        the node and the dependencies that need to be partitioned with it

        Args:
            node: Node to be partitioned
            ep: Exported program of the graph module
        Returns:
            List of nodes that can be partitioned
        """
        pass

    def set_enabled_precision_types(
        self, precision_types: Optional[List[ConfigPrecisionType]]
    ):
        """
        Set the enabled precision types.

        We take the intersection of the precision_types we wish to enable with
        the precision types that this config supports. If that intersection is
        empty, i.e. the config supports none of the precision types we want to
        enable, then this config will not partition anything and will return
        False from the common constraint checks.
        """

        if precision_types:
            enabled_precisions = []
            for precision in precision_types:
                if precision in self.supported_precision_types():
                    enabled_precisions.append(precision)

            self.enabled_precision_types = enabled_precisions

    def check_common_constraints(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> bool:
        """
        Checks common XNNPACK constraints

        Args:
            node (torch.fx.Node): Node to check common constraints against
            ep (ExportedProgram): Exported Program to check constraints against

        Returns:
            True if this node is partitionable, False otherwise
        """
        assert (
            node.op == "call_function"
            and format_target_name(node.target.__name__)  # pyre-ignore
            == self.target_name
        )

        if len(self.enabled_precision_types) == 0:
            why(node, reason="no enabled precision types")
            return False

        has_valid_dtypes = self._check_node_has_valid_dtype(node)
        if not has_valid_dtypes:
            why(node, reason="invalid dtype")
            return False

        return True

    def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
        # Check that the inputs are valid dtypes.
        # Gather all args which are nodes
        args_to_check = []
        for arg in node.args:
            if isinstance(arg, (list, tuple)):
                for item in arg:
                    if isinstance(item, torch.fx.Node):
                        args_to_check.append(item)

            if isinstance(arg, torch.fx.Node):
                args_to_check.append(arg)

        for arg in args_to_check:
            arg_val = arg.meta.get("val", None)

            if arg_val is None or isinstance(arg_val, tuple):
                continue

            # Being conservative for now, UX >> Perf
            # TODO: We need a pass to scrub these out.
            if not isinstance(arg_val, torch.Tensor):
                return False

            # XNNPACK does not support empty tensors
            if arg_val.numel() == 0:
                return False

            if arg_val.dtype not in valid_dtypes:
                return False

        return True

    def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
        # Check that the outputs are valid dtypes
        node_val = node.meta.get("val", None)
        if node_val is None:
            return True

        if not isinstance(node_val, tuple):
            node_val = (node_val,)

        for val in node_val:
            if not isinstance(val, torch.Tensor):
                return False

            if val.dtype not in valid_dtypes:
                return False

        return True

    def _check_node_has_valid_dtype(self, node):
        valid_dtypes = {
            torch.float32,
            torch.float16,
            torch.int8,
            torch.qint8,
        }
        if node.op not in ("placeholder", "call_function", "get_attr"):
            return False

        return self._check_inputs_are_valid_dtypes(
            node, valid_dtypes
        ) and self._check_outputs_are_valid_dtypes(node, valid_dtypes)
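

# Illustrative sketch only, not part of this module's API: a concrete config
# would subclass XNNPartitionerConfig and implement the two abstract methods.
# The target name "relu.default" and the single-node partition below are
# assumptions chosen purely to show the shape of such a subclass.
class _ExampleReLUConfig(XNNPartitionerConfig):
    target_name = "relu.default"

    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        # This hypothetical config only handles fp32 graphs.
        return [ConfigPrecisionType.FP32]

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        # A unary op has no extra dependencies, so the partition is just the
        # node itself, provided it passes the shared constraint checks.
        if not self.check_common_constraints(node, ep):
            return []
        return [node]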